package com.xebialabs.platform.sso.oidc.policy.impl;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.util.Assert;

import com.xebialabs.platform.sso.oidc.exceptions.InvalidRoleClaimsListException;
import com.xebialabs.platform.sso.oidc.policy.ClaimsToGrantedAuthoritiesPolicy;

import static java.lang.String.format;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * Default policy that maps a group/role claim of an OIDC claim set to granted authorities.
 * <p>
 * The claim is addressed by a dot-separated path (for example {@code realm.realm_roles.roles}); a backslash
 * escapes the following character so that a literal dot can be part of a path component. Every role found at
 * that path becomes a {@link SimpleGrantedAuthority}.
 */
public class DefaultClaimsToGrantedAuthoritiesPolicy implements ClaimsToGrantedAuthoritiesPolicy {
    private static final Logger logger = getLogger(DefaultClaimsToGrantedAuthoritiesPolicy.class);

    // Do the way keycloak does it as Platform SSO uses keycloak as backend.
    // A character in a claim component is either a literal character escaped by a backslash (\., \\, \_, \q, etc.)
    // or any character other than backslash (escaping) and dot (claim component separator).
    private static final Pattern CLAIM_COMPONENT = Pattern.compile("^((\\\\.|[^\\\\.])+?)\\.");

    // Matches a backslash followed by any character; group 1 is the escaped character itself.
    private static final Pattern BACKSLASH_CHARACTER = Pattern.compile("\\\\(.)");

    private final String rolesClaimName;

    // The parsed form of rolesClaimName. It only depends on the immutable name, so it is computed once here
    // instead of being re-parsed for every token.
    private final List<String> rolesClaimPath;

    /**
     * Creates a policy that reads roles from the given claim.
     *
     * @param rolesClaimName dot-separated path of the claim holding the roles; must contain text
     * @throws IllegalArgumentException if {@code rolesClaimName} is {@code null}, empty or blank
     */
    public DefaultClaimsToGrantedAuthoritiesPolicy(String rolesClaimName) {
        Assert.hasText(rolesClaimName, "rolesClaimName must contain a property name");
        this.rolesClaimName = rolesClaimName;
        this.rolesClaimPath = Collections.unmodifiableList(splitClaimPath(rolesClaimName));
    }

    /**
     * Converts the roles found under the configured claim into granted authorities.
     *
     * @param oidcClaims the OIDC claim set; {@code null} is treated as an empty claim set
     * @return one authority per role, in claim order; empty when the claim (or one of its parents) is absent
     * @throws InvalidRoleClaimsListException if the claim is neither a string nor a list of strings, or if a
     *                                        parent component of the claim path is not a JSON object
     */
    public List<GrantedAuthority> claimsToGrantedAuthorities(Map<String, Object> oidcClaims) {
        logger.debug("Got claim path: {} for rolesClaimName: {}", rolesClaimPath, rolesClaimName);

        List<String> extractedRoles = getRolesFromOidcClaims(oidcClaims, rolesClaimPath);
        logger.debug("Extracted groups/roles: {}", extractedRoles);

        return extractedRoles.stream().map(SimpleGrantedAuthority::new).collect(Collectors.toList());
    }

    /**
     * Splits the claim name and returns list of claim path in the claim set. It will take care of escape characters.
     * For example, for claimName as realm.realm_roles.roles, returns the List(realm, realm_roles, roles)
     *
     * @param claimName the claim name path
     * @return the claim path as list
     */
    private static List<String> splitClaimPath(String claimName) {
        final List<String> claimComponents = new ArrayList<>();
        final Matcher m = CLAIM_COMPONENT.matcher(claimName);
        int start = 0;
        while (m.find()) {
            claimComponents.add(unescape(m.group(1)));
            start = m.end();
            // This is necessary to match the start of region as the start of string as determined by ^
            m.region(start, claimName.length());
        }
        // Whatever follows the last separator (or the whole name when there is none) is the final component.
        if (claimName.length() > start) {
            claimComponents.add(unescape(claimName.substring(start)));
        }
        return claimComponents;
    }

    /**
     * Removes the escaping backslashes from a single claim path component.
     *
     * @param component raw component text, possibly containing backslash escapes
     * @return the component with every backslash-escaped character replaced by the character itself
     */
    private static String unescape(String component) {
        return BACKSLASH_CHARACTER.matcher(component).replaceAll("$1");
    }

    /**
     * Returns the list of roles for given path from oidc claims
     *
     * @param oidcClaims the oidc claims, may be {@code null}
     * @param claimPath the claim path
     * @return the list of roles; empty when the path cannot be followed to its end
     * @throws InvalidRoleClaimsListException if a parent component is present but is not a JSON object, or if
     *                                        the value at the end of the path is not a string or list of strings
     */
    private List<String> getRolesFromOidcClaims(final Map<String, Object> oidcClaims, final List<String> claimPath) {
        final List<String> roles = new ArrayList<>();
        if (oidcClaims == null || claimPath.isEmpty()) {
            return roles;
        }

        // Walk every component except the last one; each must resolve to a nested JSON object.
        final int leafIndex = claimPath.size() - 1;
        Map<?, ?> jsonObject = oidcClaims;
        for (String component : claimPath.subList(0, leafIndex)) {
            final Object nested = jsonObject.get(component);
            if (nested == null) {
                // A missing parent simply means the user has no roles.
                return roles;
            }
            if (!(nested instanceof Map)) {
                throw new InvalidRoleClaimsListException(format(
                        "Role claims in property [%s] could not be resolved: claim component [%s] is not an object",
                        rolesClaimName, component));
            }
            jsonObject = (Map<?, ?>) nested;
        }

        final String leaf = claimPath.get(leafIndex);
        if (!jsonObject.containsKey(leaf)) {
            // Absent claim: no roles. An explicit null value falls through and is rejected as "not a list".
            return roles;
        }

        final Object last = jsonObject.get(leaf);
        if (last instanceof String) {
            roles.add((String) last);
            return roles;
        }
        if (!(last instanceof List)) {
            throw new InvalidRoleClaimsListException(format("Role claims in property [%s] are not a list", rolesClaimName));
        }
        // Validate each element explicitly instead of relying on a ClassCastException later on.
        for (Object role : (List<?>) last) {
            if (!(role instanceof String)) {
                throw new InvalidRoleClaimsListException(format("Role claims name in property [%s] contained non string values", rolesClaimName));
            }
            roles.add((String) role);
        }
        return roles;
    }

}
