package com.xebialabs.platform.sso.oidc.web;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.web.logout.OidcClientInitiatedLogoutSuccessHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * A logout success handler for initiating OIDC logout through the user agent.
 *
 * <p>This class is similar to {@link OidcClientInitiatedLogoutSuccessHandler}. It is customised to
 * prevent encoding of URIs: the provider's {@code end_session_endpoint} is kept as a plain string and
 * parsed with {@link UriComponentsBuilder#fromUriString(String)} when the logout URL is assembled.
 *
 * <p>This class can be removed once the underlying issue has been fixed by the spring-security team.
 * NOTE(review): the upstream issue is not referenced here - confirm it is resolved before removing.
 */
public class CustomOidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
    private final ClientRegistrationRepository clientRegistrationRepository;

    /** Optional URI template for {@code post_logout_redirect_uri}; {@code null} means "do not send one". */
    private String postLogoutRedirectUri;

    /**
     * Creates the handler.
     *
     * @param clientRegistrationRepository repository used to look up the client registration of the
     *                                     authenticated user; must not be {@code null}
     * @throws IllegalArgumentException if {@code clientRegistrationRepository} is {@code null}
     */
    public CustomOidcClientInitiatedLogoutSuccessHandler(ClientRegistrationRepository clientRegistrationRepository) {
        Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
        this.clientRegistrationRepository = clientRegistrationRepository;
    }

    /**
     * Determines where the user agent is sent after a successful local logout.
     *
     * <p>When the authentication is an {@link OAuth2AuthenticationToken} whose principal is an
     * {@link OidcUser}, and the provider metadata of the matching client registration contains an
     * {@code end_session_endpoint}, the result is that endpoint with {@code id_token_hint} (and, if
     * configured, {@code post_logout_redirect_uri}) appended. In every other case the target URL
     * computed by the superclass is returned.
     *
     * @param request        the current request
     * @param response       the current response
     * @param authentication the authentication that was just logged out; may be {@code null}
     * @return the URL to redirect to
     */
    @Override
    protected String determineTargetUrl(HttpServletRequest request,
                                        HttpServletResponse response, Authentication authentication) {
        String targetUrl = null;
        if (authentication instanceof OAuth2AuthenticationToken && authentication.getPrincipal() instanceof OidcUser) {
            String registrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
            ClientRegistration clientRegistration = this.clientRegistrationRepository
                    .findByRegistrationId(registrationId);
            String endSessionEndpoint = this.endSessionEndpoint(clientRegistration);
            // A non-null endpoint implies a non-null clientRegistration (see endSessionEndpoint).
            if (endSessionEndpoint != null) {
                String idToken = idToken(authentication);
                URI logoutRedirectUri = postLogoutRedirectUri(request, clientRegistration);
                targetUrl = endpointUri(endSessionEndpoint, idToken, logoutRedirectUri);
            }
        }
        return (targetUrl != null) ? targetUrl : super.determineTargetUrl(request, response);
    }

    /**
     * Reads the {@code end_session_endpoint} entry from the provider configuration metadata.
     *
     * @param clientRegistration the registration to inspect; may be {@code null}
     * @return the endpoint as a string, or {@code null} if the registration is {@code null} or the
     *         metadata has no such entry
     */
    private String endSessionEndpoint(ClientRegistration clientRegistration) {
        if (clientRegistration == null) {
            return null;
        }
        Object endSessionEndpoint = clientRegistration.getProviderDetails().getConfigurationMetadata()
                .get("end_session_endpoint");
        return (endSessionEndpoint != null) ? endSessionEndpoint.toString() : null;
    }

    /**
     * Extracts the raw ID token value from the authentication's principal.
     *
     * @param authentication an authentication whose principal is an {@link OidcUser}
     * @return the ID token value, sent to the provider as {@code id_token_hint}
     */
    private String idToken(Authentication authentication) {
        return ((OidcUser) authentication.getPrincipal()).getIdToken().getTokenValue();
    }

    /**
     * Expands the configured post-logout redirect URI template against the current request.
     *
     * <p>The base URL is the full request URL with its path replaced by the context path and its
     * query and fragment removed.
     *
     * @param request            the current request, used to derive the base URL variables
     * @param clientRegistration the registration supplying {@code registrationId}; must not be {@code null}
     * @return the expanded URI, or {@code null} if no template has been configured
     */
    private URI postLogoutRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
        if (this.postLogoutRedirectUri == null) {
            return null;
        }
        // fromUriString replaces fromHttpUrl, which Spring Framework deprecated in 6.2.
        UriComponents uriComponents = UriComponentsBuilder
                .fromUriString(UrlUtils.buildFullRequestUrl(request))
                .replacePath(request.getContextPath())
                .replaceQuery(null)
                .fragment(null)
                .build();

        Map<String, String> uriVariables = new HashMap<>();
        String scheme = uriComponents.getScheme();
        uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
        uriVariables.put("baseUrl", uriComponents.toUriString());

        String host = uriComponents.getHost();
        uriVariables.put("baseHost", (host != null) ? host : "");

        String path = uriComponents.getPath();
        uriVariables.put("basePath", (path != null) ? path : "");

        // -1 means the parsed URL carries no explicit port; otherwise keep the leading ':' so that
        // "{baseHost}{basePort}" can be concatenated directly in a template.
        int port = uriComponents.getPort();
        uriVariables.put("basePort", (port == -1) ? "" : ":" + port);

        uriVariables.put("registrationId", clientRegistration.getRegistrationId());

        return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri)
                .buildAndExpand(uriVariables)
                .toUri();
    }

    /**
     * Builds the final logout URL sent to the user agent.
     *
     * @param endSessionEndpoint    the provider's end-session endpoint
     * @param idToken               the ID token value, added as {@code id_token_hint}
     * @param postLogoutRedirectUri the redirect target, added as {@code post_logout_redirect_uri};
     *                              omitted when {@code null}
     * @return the endpoint with the query parameters appended, UTF-8 encoded
     */
    private String endpointUri(String endSessionEndpoint, String idToken, URI postLogoutRedirectUri) {
        UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(endSessionEndpoint);
        builder.queryParam("id_token_hint", idToken);
        if (postLogoutRedirectUri != null) {
            builder.queryParam("post_logout_redirect_uri", postLogoutRedirectUri);
        }
        return builder.encode(StandardCharsets.UTF_8).build().toUriString();
    }

    /**
     * Set the post logout redirect uri template to use. Supports the {@code "{baseUrl}"},
     * {@code "{baseScheme}"}, {@code "{baseHost}"}, {@code "{basePort}"}, {@code "{basePath}"} and
     * {@code "{registrationId}"} placeholders, for example:
     *
     * <pre>
     * 	handler.setPostLogoutRedirectUri("{baseUrl}");
     * </pre>
     * <p>
     * will make so that {@code post_logout_redirect_uri} will be set to the base url for the client
     * application.
     *
     * @param postLogoutRedirectUri - A template for creating the {@code post_logout_redirect_uri}
     *                              query parameter; must not be {@code null}
     * @throws IllegalArgumentException if {@code postLogoutRedirectUri} is {@code null}
     * @since 5.3
     */
    public void setPostLogoutRedirectUri(String postLogoutRedirectUri) {
        Assert.notNull(postLogoutRedirectUri, "postLogoutRedirectUri cannot be null");
        this.postLogoutRedirectUri = postLogoutRedirectUri;
    }
}
