package org.springframework.security.saml2.provider.service.web;

import jakarta.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
import java.util.function.Function;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
import org.springframework.http.HttpMethod;
import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.core.Saml2ErrorCodes;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;

/* loaded from: input_file:BOOT-INF/lib/spring-security-saml2-service-provider-6.3.7.jar:org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.class */
public final class OpenSamlAuthenticationTokenConverter implements AuthenticationConverter {
    /** Decoder applied to the SAMLResponse parameter; assigned in the static initializer from {@code Base64.getMimeDecoder()}. */
    private static final Base64.Decoder BASE64;
    /** Validator run on the encoded SAMLResponse before {@link #BASE64} decodes it (see {@code samlDecode}). */
    private static final Base64Checker BASE_64_CHECKER;
    /** Repository queried for the {@link RelyingPartyRegistration} by registration id or by asserting party entity id. */
    private final RelyingPartyRegistrationRepository registrations;
    /**
     * Decides which requests this converter handles; {@code convert} returns {@code null} when it does not match.
     * Defaults to the filter's processing URI or the plain "/login/saml2/sso" path and can be replaced
     * through {@code setRequestMatcher}.
     */
    private RequestMatcher requestMatcher = new OrRequestMatcher(new AntPathRequestMatcher(Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI), new AntPathRequestMatcher("/login/saml2/sso"));
    /** Parser pool obtained from the OpenSAML {@link XMLObjectProviderRegistry}; used by {@code parse} to read the response XML. */
    private final ParserPool parserPool;
    /** Unmarshaller registered for {@code Response.DEFAULT_ELEMENT_NAME}; used by {@code parse}. */
    private final ResponseUnmarshaller unmarshaller;
    /**
     * Looks up the authentication request associated with the current HTTP request.
     * Initialized by the constructor from an {@code HttpSessionSaml2AuthenticationRequestRepository} and
     * replaceable through {@code setAuthenticationRequestRepository}.
     */
    private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:BOOT-INF/lib/spring-security-saml2-service-provider-6.3.7.jar:org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter$Base64Checker.class */
    /**
     * Pre-flight check applied to a Base64 payload before it is handed to the lenient MIME decoder.
     *
     * <p>Only characters of the standard Base64 alphabet ({@code A-Z a-z 0-9 + /}) are counted;
     * every other character (padding, line breaks, anything else) is skipped. A payload is
     * acceptable when the counted characters form whole 4-character groups, or when the bits of
     * the final character that do not contribute to a decoded byte are all zero.
     */
    public static class Base64Checker {
        /** Maps a code point in the range 0-255 to its 6-bit Base64 value, or -1 when it is not in the alphabet. */
        private static final int[] values = genValueMapping();

        Base64Checker() {
        }

        /** Builds the 256-entry lookup table for the standard Base64 alphabet. */
        private static int[] genValueMapping() {
            byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".getBytes(StandardCharsets.ISO_8859_1);
            int[] table = new int[256];
            Arrays.fill(table, -1);
            for (int i = 0; i < alphabet.length; i++) {
                table[alphabet[i] & 255] = i;
            }
            return table;
        }

        /**
         * Reports whether {@code str} is a canonical Base64 payload.
         *
         * @param str the encoded text to inspect; must not be {@code null}
         * @return {@code true} when the alphabet characters form complete groups or the unused
         * trailing bits of an incomplete final group are zero
         * @throws NullPointerException if {@code str} is {@code null}
         */
        boolean isAcceptable(String str) {
            int alphabetChars = 0;
            int lastValue = -1;
            for (int i = 0; i < str.length(); i++) {
                char c = str.charAt(i);
                // Characters above U+00FF must not be folded onto the table with a bit mask:
                // Base64.Decoder.decode(String) converts the text with ISO-8859-1, which turns them
                // into '?', and the MIME decoder then skips them. Counting them here (e.g. U+0141
                // aliasing to 'A') would let such characters pad the count and defeat the check.
                int value = (c < values.length) ? values[c] : -1;
                if (value != -1) {
                    lastValue = value;
                    alphabetChars++;
                }
            }
            switch (alphabetChars % 4) {
                case 0:
                    // Only complete 4-character groups.
                    return true;
                case 2:
                    // 12 bits carry one byte: the low 4 bits of the last character are unused.
                    return (lastValue & 15) == 0;
                case 3:
                    // 18 bits carry two bytes: the low 2 bits of the last character are unused.
                    return (lastValue & 3) == 0;
                case 1:
                default:
                    // A single leftover character cannot encode a full byte.
                    return false;
            }
        }

        /**
         * Rejects payloads that {@link #isAcceptable(String)} does not accept.
         *
         * @param str the encoded text to inspect
         * @throws IllegalArgumentException if the payload is not acceptable
         */
        void checkAcceptable(String str) {
            if (!isAcceptable(str)) {
                throw new IllegalArgumentException("Unaccepted Encoding");
            }
        }
    }

    /**
     * Creates a converter backed by the given registration repository.
     *
     * <p>The XML parser pool and the {@code Response} unmarshaller are taken from the OpenSAML
     * configuration, and authentication requests are looked up through a new
     * {@code HttpSessionSaml2AuthenticationRequestRepository} until another repository is set.
     *
     * @param relyingPartyRegistrationRepository the repository used to resolve registrations; must not be {@code null}
     * @throws IllegalArgumentException if the repository is {@code null}
     */
    public OpenSamlAuthenticationTokenConverter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
        Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
        XMLObjectProviderRegistry providerRegistry = (XMLObjectProviderRegistry) ConfigurationService.get(XMLObjectProviderRegistry.class);
        this.parserPool = providerRegistry.getParserPool();
        this.unmarshaller = (ResponseUnmarshaller) XMLObjectProviderRegistrySupport
                .getUnmarshallerFactory()
                .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
        this.registrations = relyingPartyRegistrationRepository;
        this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
    }

    /**
     * Builds a {@link Saml2AuthenticationToken} from a request carrying a SAMLResponse parameter.
     *
     * <p>The relying party registration is resolved by trying, in order: the stored authentication
     * request, the registration id extracted by the request matcher, and the issuer found inside
     * the response itself.
     *
     * @param httpServletRequest the incoming request
     * @return the token, or {@code null} when the request has no SAMLResponse parameter, does not
     * match the configured matcher, or no registration could be resolved
     */
    @Override
    public Saml2AuthenticationToken convert(HttpServletRequest httpServletRequest) {
        boolean hasSamlResponse = httpServletRequest.getParameter(Saml2ParameterNames.SAML_RESPONSE) != null;
        if (!hasSamlResponse) {
            return null;
        }
        RequestMatcher.MatchResult match = this.requestMatcher.matcher(httpServletRequest);
        if (!match.isMatch()) {
            return null;
        }
        Saml2AuthenticationToken fromStoredRequest = tokenByAuthenticationRequest(httpServletRequest);
        if (fromStoredRequest != null) {
            return fromStoredRequest;
        }
        Saml2AuthenticationToken fromRegistrationId = tokenByRegistrationId(httpServletRequest, match);
        if (fromRegistrationId != null) {
            return fromRegistrationId;
        }
        return tokenByEntityId(httpServletRequest);
    }

    /**
     * Resolves the token through the authentication request stored for this HTTP request.
     *
     * @param httpServletRequest the incoming request
     * @return the token, or {@code null} when no authentication request is stored or its
     * registration id is unknown to the repository
     */
    private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest httpServletRequest) {
        AbstractSaml2AuthenticationRequest pendingRequest = loadAuthenticationRequest(httpServletRequest);
        if (pendingRequest != null) {
            String registrationId = pendingRequest.getRelyingPartyRegistrationId();
            RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
            return tokenByRegistration(httpServletRequest, registration, pendingRequest);
        }
        return null;
    }

    /**
     * Resolves the token through the registration id captured by the request matcher.
     *
     * @param httpServletRequest the incoming request
     * @param matchResult the successful match produced by the configured request matcher
     * @return the token, or {@code null} when the match carries no registration id or the id is
     * unknown to the repository
     */
    private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest httpServletRequest, RequestMatcher.MatchResult matchResult) {
        // Plain literal instead of a constant borrowed from the reactive OAuth2 client module:
        // this SAML code must not depend on an unrelated module for a path-variable name.
        // NOTE(review): expected to match the {registrationId} template variable of the
        // processing URI used by the default request matcher - confirm.
        String registrationId = matchResult.getVariables().get("registrationId");
        if (registrationId == null) {
            return null;
        }
        RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
        return tokenByRegistration(httpServletRequest, registration, null);
    }

    /**
     * Resolves the token through the issuer named inside the SAML response.
     *
     * <p>The payload is Base64-decoded and, for GET requests, inflated before it is parsed, using
     * the same decoding path as {@code tokenByRegistration}; parsing the still-deflated bytes
     * would always fail.
     *
     * @param httpServletRequest the incoming request carrying the SAMLResponse parameter
     * @return the token, or {@code null} when the response names no issuer or no unique
     * registration exists for that issuer
     * @throws Saml2AuthenticationException if the payload cannot be decoded or inflated
     * @throws Saml2Exception if the payload cannot be parsed as a response
     */
    private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest httpServletRequest) {
        byte[] decoded = samlDecode(httpServletRequest.getParameter(Saml2ParameterNames.SAML_RESPONSE));
        Response response = parse(inflateIfRequired(httpServletRequest, decoded));
        // NOTE(review): the Issuer element is presumably optional on a Response - without it there
        // is nothing to look up, so report "no registration" instead of failing with an NPE.
        String issuer = (response.getIssuer() != null) ? response.getIssuer().getValue() : null;
        if (issuer == null) {
            return null;
        }
        RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
        return tokenByRegistration(httpServletRequest, registration, null);
    }

    /**
     * Builds the token for an already resolved registration.
     *
     * <p>The SAMLResponse parameter is decoded (and inflated for GET requests), and the
     * registration's entity id and assertion consumer service location are passed through the
     * request-specific URI resolver before the token is created.
     *
     * @param httpServletRequest the incoming request carrying the SAMLResponse parameter
     * @param relyingPartyRegistration the resolved registration, may be {@code null}
     * @param abstractSaml2AuthenticationRequest the stored authentication request, may be {@code null}
     * @return the token, or {@code null} when {@code relyingPartyRegistration} is {@code null}
     */
    private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest httpServletRequest, RelyingPartyRegistration relyingPartyRegistration, AbstractSaml2AuthenticationRequest abstractSaml2AuthenticationRequest) {
        if (relyingPartyRegistration != null) {
            String encodedResponse = httpServletRequest.getParameter(Saml2ParameterNames.SAML_RESPONSE);
            String samlResponse = inflateIfRequired(httpServletRequest, samlDecode(encodedResponse));
            RelyingPartyRegistrationPlaceholderResolvers.UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(httpServletRequest, relyingPartyRegistration);
            RelyingPartyRegistration resolvedRegistration = relyingPartyRegistration.mutate()
                    .entityId(resolver.resolve(relyingPartyRegistration.getEntityId()))
                    .assertionConsumerServiceLocation(resolver.resolve(relyingPartyRegistration.getAssertionConsumerServiceLocation()))
                    .build();
            return new Saml2AuthenticationToken(resolvedRegistration, samlResponse, abstractSaml2AuthenticationRequest);
        }
        return null;
    }

    /**
     * Replaces the repository used to look up the authentication request for an incoming response.
     *
     * @param saml2AuthenticationRequestRepository the repository to use; must not be {@code null}
     * @throws IllegalArgumentException if the repository is {@code null}
     */
    public void setAuthenticationRequestRepository(Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> saml2AuthenticationRequestRepository) {
        // Assert.notNull already rejects null, so no second null check is needed before binding
        // the method reference.
        Assert.notNull(saml2AuthenticationRequestRepository, "authenticationRequestRepository cannot be null");
        this.loader = saml2AuthenticationRequestRepository::loadAuthenticationRequest;
    }

    /**
     * Replaces the matcher that decides which requests {@code convert} handles.
     *
     * @param requestMatcher the matcher to use; must not be {@code null}
     */
    public void setRequestMatcher(RequestMatcher requestMatcher) {
        Assert.notNull(requestMatcher, "requestMatcher cannot be null");
        this.requestMatcher = requestMatcher;
    }

    /**
     * Looks up the authentication request for the given HTTP request through the configured loader.
     *
     * @param httpServletRequest the incoming request
     * @return whatever the loader returns; callers treat {@code null} as "no stored request"
     */
    private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest httpServletRequest) {
        return this.loader.apply(httpServletRequest);
    }

    /**
     * Turns the Base64-decoded payload into text, inflating it first when the request method is GET.
     *
     * @param httpServletRequest the incoming request, consulted for its HTTP method
     * @param bArr the Base64-decoded payload
     * @return the inflated text for GET requests, otherwise the bytes read as UTF-8
     */
    private String inflateIfRequired(HttpServletRequest httpServletRequest, byte[] bArr) {
        if (HttpMethod.GET.matches(httpServletRequest.getMethod())) {
            return samlInflate(bArr);
        }
        return new String(bArr, StandardCharsets.UTF_8);
    }

    /**
     * Validates and Base64-decodes the SAMLResponse parameter value.
     *
     * @param str the encoded parameter value
     * @return the decoded bytes
     * @throws Saml2AuthenticationException with code {@code INVALID_RESPONSE} when validation or
     * decoding fails for any reason
     */
    private byte[] samlDecode(String str) {
        byte[] decoded;
        try {
            // Reject non-canonical encodings up front; the MIME decoder itself is lenient.
            BASE_64_CHECKER.checkAcceptable(str);
            decoded = BASE64.decode(str);
        } catch (Exception ex) {
            Saml2Error error = new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse");
            throw new Saml2AuthenticationException(error, ex);
        }
        return decoded;
    }

    /**
     * Inflates a raw-DEFLATE payload (no zlib header) and returns it as UTF-8 text.
     *
     * @param bArr the compressed bytes
     * @return the inflated content decoded as UTF-8
     * @throws Saml2AuthenticationException with code {@code INVALID_RESPONSE} when the payload
     * cannot be inflated
     */
    private String samlInflate(byte[] bArr) {
        // "true" selects raw DEFLATE without the zlib wrapper.
        Inflater inflater = new Inflater(true);
        try {
            ByteArrayOutputStream inflated = new ByteArrayOutputStream();
            try (InflaterOutputStream inflaterStream = new InflaterOutputStream(inflated, inflater)) {
                inflaterStream.write(bArr);
                inflaterStream.finish();
            }
            return new String(inflated.toByteArray(), StandardCharsets.UTF_8);
        } catch (Exception e) {
            throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), e);
        } finally {
            // The stream only releases inflaters it created itself; one passed in by the caller
            // has to be ended explicitly to free its native resources promptly.
            inflater.end();
        }
    }

    /**
     * Parses the given XML text and unmarshalls its document element into a {@link Response}.
     *
     * @param str the XML text of the SAML response
     * @return the unmarshalled response
     * @throws Saml2Exception if the text cannot be parsed or unmarshalled
     */
    private Response parse(String str) throws Saml2Exception {
        try {
            ByteArrayInputStream xml = new ByteArrayInputStream(str.getBytes(StandardCharsets.UTF_8));
            return (Response) this.unmarshaller.unmarshall(this.parserPool.parse(xml).getDocumentElement());
        } catch (Exception e) {
            // The message names the type actually being deserialized here (a Response).
            throw new Saml2Exception("Failed to deserialize Response", e);
        }
    }

    // Class initialization: runs once, before any instance of this converter can be constructed.
    static {
        // NOTE(review): presumably bootstraps the OpenSAML configuration that the constructor
        // reads through ConfigurationService / XMLObjectProviderRegistrySupport - confirm.
        OpenSamlInitializationService.initialize();
        // MIME flavour of the JDK decoder; samlDecode pairs it with the checker below.
        BASE64 = Base64.getMimeDecoder();
        BASE_64_CHECKER = new Base64Checker();
    }
}
