package com.xebialabs.xlrelease.auth.oidc.web.handlers;

import java.io.IOException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.session.SessionAuthenticationException;

import static com.xebialabs.xlrelease.auth.oidc.config.OpenIdConnectConfig.OIDC_LOGIN_PATH_NAME;
import static com.xebialabs.xlrelease.auth.oidc.config.OpenIdConnectConfig.OIDC_PROCESSING_URL;
import static com.xebialabs.xlrelease.auth.oidc.web.XlReleaseLoginFormFilter.ERROR_PARAMETER_NAME;
import static com.xebialabs.xlrelease.auth.oidc.web.XlReleaseLoginFormFilter.LOGIN_PATH_NAME;

/**
 * Authentication failure handler for the OIDC login flow, kept for backward compatibility so that a
 * user can be sent back to the IdP login page (e.g. after logout).
 *
 * <p>Routing rules:
 * <ul>
 *   <li>If the request hit the OIDC processing URL without any query string, the user is redirected
 *       to {@code OIDC_LOGIN_PATH_NAME}.</li>
 *   <li>Otherwise, bad-credentials, session-authentication, disabled-account and standard OAuth2
 *       failures redirect to {@code LOGIN_PATH_NAME} with the (URL-encoded) failure message in the
 *       {@code ERROR_PARAMETER_NAME} query parameter.</li>
 *   <li>Every other failure redirects to {@code OIDC_LOGIN_PATH_NAME}.</li>
 * </ul>
 */
public class OidcLoginFailureHandler implements AuthenticationFailureHandler {

    /**
     * OAuth2 error codes produced by Spring Security itself rather than the OAuth 2.0 spec; failures
     * carrying one of these codes are not reported on the login page and instead redirect to
     * {@code OIDC_LOGIN_PATH_NAME}.
     */
    private static final Set<String> NON_STANDARD_OAUTH2_ERROR_CODES = Set.of(
            "authorization_request_not_found" // org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter.AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE
    );

    /**
     * Redirects the user either to the login page with an error message or to the OIDC login path.
     *
     * @param request   the request during which authentication failed
     * @param response  the response used to send the redirect
     * @param exception the authentication failure
     * @throws IOException if sending the redirect fails
     */
    @Override
    public void onAuthenticationFailure(final HttpServletRequest request,
                                        final HttpServletResponse response,
                                        final AuthenticationException exception) throws IOException, ServletException {
        if (!isRedirectError(request) && isReportableToUser(exception)) {
            response.sendRedirect(loginErrorUrl(exception));
        } else {
            response.sendRedirect(OIDC_LOGIN_PATH_NAME);
        }
    }

    /** True when the request targets the OIDC processing URL but carries no query string at all. */
    private static boolean isRedirectError(final HttpServletRequest request) {
        return request.getRequestURI().contains(OIDC_PROCESSING_URL) && request.getQueryString() == null;
    }

    /** True for failure types whose message should be shown on the regular login page. */
    private static boolean isReportableToUser(final AuthenticationException exception) {
        return exception instanceof BadCredentialsException
                || exception instanceof SessionAuthenticationException
                || exception instanceof DisabledException
                || isStandardOAuth2Error(exception);
    }

    /** True for an OAuth2 failure whose error code is not one of {@link #NON_STANDARD_OAUTH2_ERROR_CODES}. */
    private static boolean isStandardOAuth2Error(final AuthenticationException exception) {
        if (!(exception instanceof OAuth2AuthenticationException)) {
            return false;
        }
        final String errorCode = ((OAuth2AuthenticationException) exception).getError().getErrorCode();
        return !NON_STANDARD_OAUTH2_ERROR_CODES.contains(errorCode);
    }

    /**
     * Builds the login-page URL carrying the failure message as the error parameter.
     * The message is URL-encoded because it may contain reserved characters (spaces, '&', '#', ...)
     * or text echoed back from the IdP, which must not break or extend the query string.
     */
    private static String loginErrorUrl(final AuthenticationException exception) {
        // Fall back to the exception type so the error parameter is never the literal "null".
        final String message = Objects.toString(exception.getMessage(), exception.getClass().getSimpleName());
        return LOGIN_PATH_NAME + "?" + ERROR_PARAMETER_NAME + "=" + URLEncoder.encode(message, StandardCharsets.UTF_8);
    }
}
