/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.saml2.provider.service.web;

import jakarta.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.function.Function;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;
import net.shibboleth.utilities.java.support.xml.ParserPool;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
import org.springframework.http.HttpMethod;
import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/**
 * An {@link AuthenticationConverter} that turns an incoming {@code SAMLResponse} request parameter
 * into a {@link Saml2AuthenticationToken}.
 *
 * <p>The relying party registration is resolved, in order, from:
 * <ol>
 *   <li>the previously saved authentication request (via the configured loader),</li>
 *   <li>the {@code registrationId} path variable captured by the request matcher,</li>
 *   <li>the {@code Issuer} of the SAML {@code Response} itself (asserting party entity ID).</li>
 * </ol>
 * If none of these yields a registration, {@link #convert(HttpServletRequest)} returns {@code null}.
 */
public final class OpenSamlAuthenticationTokenConverter
implements AuthenticationConverter {
    private static final Base64.Decoder BASE64;
    private static final Base64Checker BASE_64_CHECKER;
    private final RelyingPartyRegistrationRepository registrations;
    // Default endpoints: with and without a trailing registrationId path segment.
    private RequestMatcher requestMatcher = new OrRequestMatcher(
            new AntPathRequestMatcher("/login/saml2/sso/{registrationId}"),
            new AntPathRequestMatcher("/login/saml2/sso"));
    private final ParserPool parserPool;
    private final ResponseUnmarshaller unmarshaller;
    private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;

    /**
     * Creates a converter backed by the given registration repository. By default, saved
     * authentication requests are loaded from the HTTP session.
     *
     * @param registrations the repository used to look up relying party registrations; must not be null
     * @throws IllegalArgumentException if {@code registrations} is null
     */
    public OpenSamlAuthenticationTokenConverter(RelyingPartyRegistrationRepository registrations) {
        Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null");
        XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
        this.parserPool = registry.getParserPool();
        this.unmarshaller = (ResponseUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory()
                .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
        this.registrations = registrations;
        this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
    }

    /**
     * Converts the request into a {@link Saml2AuthenticationToken}.
     *
     * @param request the current request
     * @return the token, or {@code null} when there is no {@code SAMLResponse} parameter, the request
     *         does not match the configured matcher, or no relying party registration can be resolved
     * @throws Saml2AuthenticationException if the {@code SAMLResponse} cannot be Base64-decoded or inflated
     */
    @Override
    public Saml2AuthenticationToken convert(HttpServletRequest request) {
        String serialized = request.getParameter("SAMLResponse");
        if (serialized == null) {
            return null;
        }
        RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
        if (!result.isMatch()) {
            return null;
        }
        Saml2AuthenticationToken token = this.tokenByAuthenticationRequest(request);
        if (token == null) {
            token = this.tokenByRegistrationId(request, result);
        }
        if (token == null) {
            token = this.tokenByEntityId(request);
        }
        return token;
    }

    /** Resolves the registration from a previously saved authentication request, if any. */
    private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) {
        AbstractSaml2AuthenticationRequest authenticationRequest = this.loadAuthenticationRequest(request);
        if (authenticationRequest == null) {
            return null;
        }
        String registrationId = authenticationRequest.getRelyingPartyRegistrationId();
        RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
        return this.tokenByRegistration(request, registration, authenticationRequest);
    }

    /** Resolves the registration from the {@code registrationId} path variable, if the matcher captured one. */
    private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request, RequestMatcher.MatchResult result) {
        String registrationId = result.getVariables().get("registrationId");
        if (registrationId == null) {
            return null;
        }
        RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
        return this.tokenByRegistration(request, registration, null);
    }

    /**
     * Resolves the registration from the {@code Issuer} of the SAML Response (asserting party entity ID).
     * Returns {@code null} when the Response carries no Issuer value.
     */
    private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) {
        String serialized = request.getParameter("SAMLResponse");
        // Decode exactly as tokenByRegistration does: GET payloads are compressed and must be inflated
        // before they can be parsed as XML.
        String decoded = this.inflateIfRequired(request, this.samlDecode(serialized));
        Response response = this.parse(decoded);
        if (response.getIssuer() == null) {
            return null;
        }
        String issuer = response.getIssuer().getValue();
        if (issuer == null) {
            return null;
        }
        RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
        return this.tokenByRegistration(request, registration, null);
    }

    /**
     * Builds the token for a resolved registration. Placeholders in the entity ID and assertion consumer
     * service location are resolved against the current request.
     *
     * @return the token, or {@code null} if {@code registration} is null
     */
    private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request, RelyingPartyRegistration registration, AbstractSaml2AuthenticationRequest authenticationRequest) {
        if (registration == null) {
            return null;
        }
        String serialized = request.getParameter("SAMLResponse");
        String decoded = this.inflateIfRequired(request, this.samlDecode(serialized));
        RelyingPartyRegistrationPlaceholderResolvers.UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
        registration = registration.mutate()
                .entityId(resolver.resolve(registration.getEntityId()))
                .assertionConsumerServiceLocation(resolver.resolve(registration.getAssertionConsumerServiceLocation()))
                .build();
        return new Saml2AuthenticationToken(registration, decoded, authenticationRequest);
    }

    /**
     * Sets the repository used to load the saved authentication request.
     *
     * @param authenticationRequestRepository the repository to use; must not be null
     * @throws IllegalArgumentException if the argument is null
     */
    public void setAuthenticationRequestRepository(Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
        Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
        this.loader = authenticationRequestRepository::loadAuthenticationRequest;
    }

    /**
     * Sets the matcher that decides which requests are processed. If it captures a
     * {@code registrationId} variable, that variable is used for registration lookup.
     *
     * @param requestMatcher the matcher to use; must not be null
     * @throws IllegalArgumentException if the argument is null
     */
    public void setRequestMatcher(RequestMatcher requestMatcher) {
        Assert.notNull(requestMatcher, "requestMatcher cannot be null");
        this.requestMatcher = requestMatcher;
    }

    private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
        return this.loader.apply(request);
    }

    /** Inflates the decoded bytes for GET requests; otherwise interprets them directly as UTF-8. */
    private String inflateIfRequired(HttpServletRequest request, byte[] b) {
        if (HttpMethod.GET.matches(request.getMethod())) {
            return this.samlInflate(b);
        }
        return new String(b, StandardCharsets.UTF_8);
    }

    /**
     * Base64-decodes the payload after rejecting encodings with non-zero trailing bits.
     *
     * @throws Saml2AuthenticationException ({@code invalid_response}) on any decoding failure
     */
    private byte[] samlDecode(String base64EncodedPayload) {
        try {
            BASE_64_CHECKER.checkAcceptable(base64EncodedPayload);
            return BASE64.decode(base64EncodedPayload);
        }
        catch (Exception ex) {
            throw new Saml2AuthenticationException(new Saml2Error("invalid_response", "Failed to decode SAMLResponse"), ex);
        }
    }

    /**
     * Inflates raw DEFLATE data ({@code nowrap = true}, i.e. no zlib header) into a UTF-8 string.
     *
     * @throws Saml2AuthenticationException ({@code invalid_response}) if inflation fails
     */
    private String samlInflate(byte[] b) {
        // The Inflater is supplied explicitly, so InflaterOutputStream.close() will not end() it;
        // release its native resources ourselves in the finally block.
        Inflater inflater = new Inflater(true);
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            try (InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, inflater)) {
                inflaterOutputStream.write(b);
                inflaterOutputStream.finish();
            }
            return new String(out.toByteArray(), StandardCharsets.UTF_8);
        }
        catch (Exception ex) {
            throw new Saml2AuthenticationException(new Saml2Error("invalid_response", "Unable to inflate string"), ex);
        }
        finally {
            inflater.end();
        }
    }

    /**
     * Parses and unmarshals a SAML {@code Response} from its XML form.
     *
     * @throws Saml2Exception if the XML cannot be parsed or unmarshalled
     */
    private Response parse(String response) throws Saml2Exception {
        try {
            Document document = this.parserPool.parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8)));
            Element element = document.getDocumentElement();
            return (Response) this.unmarshaller.unmarshall(element);
        }
        catch (Exception ex) {
            throw new Saml2Exception("Failed to deserialize Response", ex);
        }
    }

    static {
        // OpenSAML must be initialized before its registries are consulted in the constructor.
        OpenSamlInitializationService.initialize();
        BASE64 = Base64.getMimeDecoder();
        BASE_64_CHECKER = new Base64Checker();
    }

    /**
     * Rejects Base64 input whose final quantum has non-zero unused bits, or whose count of alphabet
     * characters leaves a single dangling character. Characters outside the Base64 alphabet are
     * ignored, mirroring the MIME decoder used for the actual decoding.
     */
    static class Base64Checker {
        // Maps each ISO-8859-1 code unit to its 6-bit Base64 value, or -1 if not in the alphabet.
        private static final int[] VALUES = Base64Checker.genValueMapping();

        Base64Checker() {
        }

        private static int[] genValueMapping() {
            byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".getBytes(StandardCharsets.ISO_8859_1);
            int[] values = new int[256];
            Arrays.fill(values, -1);
            for (int i = 0; i < alphabet.length; ++i) {
                values[alphabet[i] & 0xFF] = i;
            }
            return values;
        }

        boolean isAcceptable(String s) {
            int goodChars = 0;
            int lastGoodCharVal = -1;
            for (int i = 0; i < s.length(); ++i) {
                char c = s.charAt(i);
                // Base64.Decoder.decode(String) encodes its input as ISO-8859-1, so chars above U+00FF
                // become '?' and are skipped by the MIME decoder. Masking them with 0xFF would instead
                // alias them onto alphabet bytes (e.g. U+0141 -> 'A') and miscount the input.
                int val = (c < VALUES.length) ? VALUES[c] : -1;
                if (val == -1) {
                    continue;
                }
                lastGoodCharVal = val;
                ++goodChars;
            }
            switch (goodChars % 4) {
                case 0:
                    return true;
                case 2:
                    // Two chars encode one byte: the low 4 bits of the last char must be zero.
                    return (lastGoodCharVal & 0xF) == 0;
                case 3:
                    // Three chars encode two bytes: the low 2 bits of the last char must be zero.
                    return (lastGoodCharVal & 3) == 0;
                default:
                    // A single leftover char cannot encode a whole byte.
                    return false;
            }
        }

        void checkAcceptable(String ins) {
            if (!this.isAcceptable(ins)) {
                throw new IllegalArgumentException("Unaccepted Encoding");
            }
        }
    }
}

