package com.xebialabs.platform.sso.crypto;

import java.io.IOException;
import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.xebialabs.xlplatform.utils.Strings;

import static java.nio.charset.StandardCharsets.UTF_8;
import static com.google.common.base.Strings.emptyToNull;

/**
 * Retrieves the RSA signing keys published at an OpenID Connect provider's JWKS endpoint and
 * makes them available by key ID ({@code kid}).
 *
 * <p>Keys are fetched once on construction and again whenever {@link #refreshKeys()} is called.
 * A refresh builds a complete new map and only then swaps it in, so readers never observe a
 * partially-populated key set. If the endpoint cannot be fetched or parsed, the failure is logged
 * and the previously known keys are kept. Individual malformed or unsupported key entries are
 * skipped without discarding the rest of the key set.
 */
public class KeyRetriever {

    private static final Logger logger = LoggerFactory.getLogger(KeyRetriever.class);

    private final String jwksUri;

    // The map is replaced wholesale by refreshKeys(); volatile ensures the newly published map is
    // visible to getKeyById() in case the two are invoked from different threads.
    private volatile Map<String, PublicKey> keysById = new LinkedHashMap<>();

    /**
     * Creates a retriever for the given JWKS endpoint and immediately performs an initial
     * {@link #refreshKeys() refresh}.
     *
     * @param jwksUri URL of the provider's JSON Web Key Set document
     */
    public KeyRetriever(String jwksUri) {
        this.jwksUri = jwksUri;
        refreshKeys();
    }

    /**
     * Returns the currently known public key with the given key ID.
     *
     * @param keyId the key ID to look up
     * @return the matching key, or {@code null} if no usable key with that ID is currently known
     */
    public PublicKey getKeyById(String keyId) {
        return keysById.get(keyId);
    }

    /**
     * Re-downloads the key set from the JWKS endpoint and replaces the known keys with the
     * usable RSA/RS256 signing keys it contains. Never throws: any failure is logged and the
     * previous key set stays in effect.
     */
    public void refreshKeys() {
        try {
            logger.info("Refreshing OIDC keys from {}...", jwksUri);
            String keySetJson = call(jwksUri);
            JSONArray keys = new JSONObject(keySetJson).getJSONArray("keys");
            Map<String, PublicKey> result = new LinkedHashMap<>();
            for (int i = 0; i < keys.length(); i++) {
                addIfUsable(keys, i, result);
            }
            keysById = result;
            logCurrentKeys();
            if (result.isEmpty()) {
                logger.warn(keys.length() == 0
                        ? "Server returned zero keys - OIDC login effectively disabled"
                        : ("OIDC provider returned " + keys.length() + " keys but none were usable"));
            }
        } catch (Exception e) {
            logger.warn("Could not properly retrieve keys from {}", jwksUri, e);
        }
    }

    /**
     * Converts entry {@code index} of the JWKS {@code keys} array into an RSA public key and
     * stores it in {@code target} under its {@code kid}, provided it is an RSA key usable for
     * RS256 signatures. A missing {@code use}/{@code alg}/{@code kty} is treated as acceptable.
     * Unsuitable entries are logged at debug level; entries that cannot be parsed are logged
     * and skipped, so one bad entry does not abort the whole refresh.
     */
    private void addIfUsable(JSONArray keys, int index, Map<String, PublicKey> target) {
        String kid = null;
        try {
            JSONObject key = keys.getJSONObject(index);
            kid = key.getString("kid");
            String use = emptyToNull(key.optString("use"));
            String alg = emptyToNull(key.optString("alg"));
            String kty = emptyToNull(key.optString("kty"));
            if (!(use == null || use.equals("sig")) || !(alg == null || alg.equals("RS256"))) {
                logger.debug("Retrieved key with id {} should have use/alg sig/RS256 but was {}/{}",
                        kid, use, alg);
            } else if (kty != null && !kty.equals("RSA")) {
                // Non-RSA keys (e.g. EC) have no "e"/"n" members and cannot be converted here.
                logger.debug("Retrieved key with id {} should have kty RSA but was {}", kid, kty);
            } else {
                target.put(kid, computeKey(key.getString("e"), key.getString("n")));
            }
        } catch (Exception e) {
            logger.warn("Skipping unusable key #{} (id {}) from {}", index, kid, jwksUri, e);
        }
    }

    /** Logs the IDs of all currently known keys, or "(none)" if there are none. */
    private void logCurrentKeys() {
        // Read the volatile field once so the list and the emptiness check agree.
        Map<String, PublicKey> current = keysById;
        logger.info("Currently known OIDC key IDs:\n * " +
                Strings.mkString(new ArrayList<>(current.keySet()), "\n * ") +
                (current.isEmpty() ? "(none)" : ""));
    }

    /**
     * Builds an RSA public key from the base64url-encoded JWK exponent and modulus.
     *
     * @param e base64url-encoded public exponent (JWK {@code e} member)
     * @param n base64url-encoded modulus (JWK {@code n} member)
     * @return the corresponding RSA public key
     * @throws InvalidKeySpecException if the values do not form a valid RSA key
     * @throws NoSuchAlgorithmException if no RSA KeyFactory is available
     * @throws IllegalArgumentException if either value is not valid base64url
     */
    private static PublicKey computeKey(final String e, final String n)
            throws InvalidKeySpecException, NoSuchAlgorithmException {
        // Signum 1: JWK encodes the values as unsigned big-endian integers.
        BigInteger exponent = new BigInteger(1, Base64.getUrlDecoder().decode(e));
        BigInteger modulus = new BigInteger(1, Base64.getUrlDecoder().decode(n));
        final RSAPublicKeySpec keySpec = new RSAPublicKeySpec(modulus, exponent);
        KeyFactory kf = KeyFactory.getInstance("RSA");
        return kf.generatePublic(keySpec);
    }

    /**
     * Performs an HTTP GET on {@code uri} and returns the response body decoded as UTF-8.
     *
     * @throws IOException on transport failure, a non-2xx status, or a response without a body
     */
    private String call(String uri) throws IOException {
        try (CloseableHttpClient client = HttpClientBuilder.create().build()) {
            HttpResponse httpResponse = client.execute(new HttpGet(uri));
            int status = httpResponse.getStatusLine().getStatusCode();
            // Fail fast with a clear message instead of handing an error page to the JSON parser.
            if (status < 200 || status >= 300) {
                throw new IOException("Unexpected HTTP status " + status + " from " + uri);
            }
            if (httpResponse.getEntity() == null) {
                throw new IOException("Empty response body from " + uri);
            }
            return IOUtils.toString(httpResponse.getEntity().getContent(), UTF_8);
        }
    }
}
