/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.genai.util.aws;

import java.io.Serializable;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.HexFormat;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.regex.Pattern;
import org.apache.commons.text.StringSubstitutor;
import org.eclipse.collections.api.block.procedure.Procedure2;
import org.eclipse.collections.api.multimap.Multimap;
import org.eclipse.collections.api.multimap.MutableMultimap;
import org.eclipse.collections.impl.factory.Multimaps;
import org.neo4j.genai.util.Hashing;
import org.neo4j.genai.util.aws.URLUtils;

/**
 * Produces the headers needed to sign a {@code POST} request with AWS Signature Version 4
 * ({@code AWS4-HMAC-SHA256}) for the {@code bedrock} service.
 *
 * <p>The request headers given at construction must contain a {@code Host} header (matched
 * case-insensitively). It takes part in the signature but is stripped from the returned headers.
 * Each call to {@code generate} works on its own copy of the request headers, so a generator can
 * be used to sign repeatedly.
 */
public class AwsSignatureV4HeaderGenerator {
    static final HexFormat HEX_FORMAT = HexFormat.of();
    // The trailing 'Z' is a literal, so the formatters are pinned to UTC: any temporal that carries
    // an instant (Instant, or a ZonedDateTime/OffsetDateTime in another zone) is converted first.
    private static final DateTimeFormatter DATE_TIME_FORMATTER =
            DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'", Locale.ROOT).withZone(ZoneOffset.UTC);
    private static final DateTimeFormatter DATE_FORMATTER =
            DateTimeFormatter.ofPattern("yyyyMMdd", Locale.ROOT).withZone(ZoneOffset.UTC);
    // Shared by the credential scope and the signing-key derivation; they must always agree.
    private static final String SERVICE = "bedrock";
    private static final String TERMINATOR = "aws4_request";
    private static final String SCOPE_TEMPLATE = "${date}/${region}/" + SERVICE + "/" + TERMINATOR;
    private static final String ALGORITHM = "AWS4-HMAC-SHA256";
    private static final String AUTHORIZATION_TEMPLATE =
            ALGORITHM + " Credential=${accessKeyId}/${scope}, SignedHeaders=${signedHeaders}, Signature=${signature}";
    private static final String HOST_HEADER = "Host";
    private static final String DATE_HEADER = "X-Amz-Date";
    private static final String CONTENT_SHA256_HEADER = "X-Amz-Content-Sha256";
    private static final String AUTHORIZATION_HEADER = "Authorization";

    private final String region;
    private final URLUtils.CanonicalURIComponents canonicalEndpointComponents;
    private final String body;
    private final MutableMultimap<String, String> headers;

    /**
     * Creates a generator for one request.
     *
     * @param region AWS region used in the credential scope and signing-key derivation
     * @param endpoint request URI; its canonical path and query are signed
     * @param body request payload; hashed as UTF-8 (must be non-null when {@code generate} is called)
     * @param requestProperties request headers to sign; copied, so later changes by the caller are not seen
     */
    public AwsSignatureV4HeaderGenerator(String region, URI endpoint, String body, Multimap<String, String> requestProperties) {
        this.canonicalEndpointComponents = new URLUtils.CanonicalURIComponents(endpoint);
        this.region = region;
        this.body = body;
        this.headers = Multimaps.mutable.list.withAll(requestProperties);
    }

    /**
     * Signs the request using the current UTC time.
     *
     * @param accessKeyId access key id placed in the {@code Credential} part of the header
     * @param secretAccessKey secret key from which the signing key is derived
     * @return a new multimap with the request headers plus {@code X-Amz-Date},
     *     {@code X-Amz-Content-Sha256} and {@code Authorization}, and without {@code Host}
     * @throws IllegalArgumentException if the request headers contain no {@code Host} header
     */
    public Multimap<String, String> generate(String accessKeyId, String secretAccessKey) {
        return this.generate(ZonedDateTime.now(ZoneOffset.UTC), accessKeyId, secretAccessKey);
    }

    /**
     * Signs the request as of {@code time}. This overload allows a fixed time to be supplied.
     *
     * @param time signing time; converted to UTC when it carries an instant
     * @param accessKeyId access key id placed in the {@code Credential} part of the header
     * @param secretAccessKey secret key from which the signing key is derived
     * @return a new multimap with the request headers plus the signing headers, without {@code Host}
     * @throws IllegalArgumentException if the request headers contain no {@code Host} header
     */
    Multimap<String, String> generate(TemporalAccessor time, String accessKeyId, String secretAccessKey) {
        String datetime = DATE_TIME_FORMATTER.format(time);
        String date = DATE_FORMATTER.format(time);
        String hashedPayload = HEX_FORMAT.formatHex(Hashing.sha256(this.body.getBytes(StandardCharsets.UTF_8)));

        // Sign a per-call copy: mutating this.headers made a second call accumulate duplicate
        // signing headers and then fail because the first call had already stripped Host.
        MutableMultimap<String, String> signed = Multimaps.mutable.list.withAll(this.headers);
        // Replace rather than append: any caller-supplied copy would otherwise be signed and sent twice.
        removeIgnoreCase(signed, DATE_HEADER);
        removeIgnoreCase(signed, CONTENT_SHA256_HEADER);
        removeIgnoreCase(signed, AUTHORIZATION_HEADER);
        signed.put(DATE_HEADER, datetime);
        signed.put(CONTENT_SHA256_HEADER, hashedPayload);

        Canonical canonical = new Canonical(this.canonicalEndpointComponents, hashedPayload, signed);
        String scope = Canonical.scope(date, this.region);
        String stringToSign = canonical.hashedCanonicalRequest(datetime, scope);
        String signature = signature(stringToSign, secretAccessKey, date, this.region);
        String authorization = StringSubstitutor.replace(AUTHORIZATION_TEMPLATE, Map.of(
                "accessKeyId", accessKeyId,
                "scope", scope,
                "signedHeaders", canonical.signedHeaders(),
                "signature", signature));
        signed.put(AUTHORIZATION_HEADER, authorization);
        // Host is only needed for signing. It is dropped from the result, case-insensitively to match
        // the check in Canonical.canonicalHeaders (presumably the HTTP client sets Host itself - confirm).
        removeIgnoreCase(signed, HOST_HEADER);
        return signed;
    }

    /** Removes all values of header {@code name}, comparing header names case-insensitively. */
    private static void removeIgnoreCase(MutableMultimap<String, String> target, String name) {
        List<String> matchingKeys = new ArrayList<>();
        target.forEachKey(key -> {
            if (key.equalsIgnoreCase(name)) {
                matchingKeys.add(key);
            }
        });
        // Collected first so the multimap is not modified while its keys are being iterated.
        for (String key : matchingKeys) {
            target.removeAll(key);
        }
    }

    /**
     * Derives the SigV4 signing key (date, region, service, terminator) and returns the
     * hex-encoded HMAC-SHA256 of {@code payload} under that key.
     */
    private static String signature(String payload, String secretAccessKey, String date, String region) {
        byte[] dateKey = Hashing.hmacSha256(("AWS4" + secretAccessKey).getBytes(StandardCharsets.UTF_8), date);
        byte[] regionKey = Hashing.hmacSha256(dateKey, region);
        byte[] serviceKey = Hashing.hmacSha256(regionKey, SERVICE);
        byte[] signingKey = Hashing.hmacSha256(serviceKey, TERMINATOR);
        return HEX_FORMAT.formatHex(Hashing.hmacSha256(signingKey, payload));
    }

    /** Builds the SigV4 canonical request for a {@code POST} and the matching signed-headers list. */
    static class Canonical {
        private static final Pattern WHITESPACE = Pattern.compile("\\s+");
        private static final String HTTP_METHOD = "POST";
        private final String canonicalRequest;
        private final String signedHeaders;

        /**
         * Assembles the canonical request: method, canonical URI, canonical query string,
         * canonical headers, a blank line, signed headers and the payload hash, separated by newlines.
         *
         * @throws IllegalArgumentException if {@code headers} has no {@code host} header
         */
        Canonical(URLUtils.CanonicalURIComponents canonicalEndpointComponents, String hashedPayload, Multimap<String, String> headers) {
            String canonicalURI = canonicalEndpointComponents.path();
            String canonicalQueryString = canonicalEndpointComponents.query();
            SortedMap<String, List<String>> canonicalHeaders = canonicalHeaders(headers);
            this.signedHeaders = signedHeaders(canonicalHeaders);
            StringBuilder request = new StringBuilder()
                    .append(HTTP_METHOD).append('\n')
                    .append(canonicalURI).append('\n')
                    .append(canonicalQueryString).append('\n');
            // Each canonical header line already ends with '\n'; the extra '\n' produces the blank line.
            addCanonicalHeadersString(request, canonicalHeaders);
            request.append('\n').append(this.signedHeaders).append('\n').append(hashedPayload);
            this.canonicalRequest = request.toString();
        }

        /**
         * Returns the SigV4 string to sign: the algorithm, {@code datetime}, {@code scope} and the
         * hex SHA-256 of the canonical request, each on its own line.
         */
        String hashedCanonicalRequest(String datetime, String scope) {
            String hash = HEX_FORMAT.formatHex(Hashing.sha256(this.canonicalRequest.getBytes(StandardCharsets.UTF_8)));
            return ALGORITHM + "\n" + datetime + "\n" + scope + "\n" + hash;
        }

        /** Returns the semicolon-separated, sorted, lower-case names of the signed headers. */
        String signedHeaders() {
            return this.signedHeaders;
        }

        /** Returns the credential scope {@code date/region/bedrock/aws4_request}. */
        static String scope(String date, String region) {
            return StringSubstitutor.replace(SCOPE_TEMPLATE, Map.of("date", date, "region", region));
        }

        /**
         * Normalizes headers for signing. Names are lower-cased; values have whitespace runs
         * collapsed to one space and are trimmed. Values of the same name are grouped in their
         * original order, and the result is sorted by name.
         *
         * @throws IllegalArgumentException if no {@code host} header is present
         */
        static SortedMap<String, List<String>> canonicalHeaders(Multimap<String, String> headers) {
            SortedMap<String, List<String>> orderedHeaders = new TreeMap<>();
            headers.forEachKeyValue((key, value) -> {
                String name = compressWhitespace(key.toLowerCase(Locale.ROOT));
                String trimmedValue = compressWhitespace(value).trim();
                orderedHeaders.computeIfAbsent(name, k -> new ArrayList<>()).add(trimmedValue);
            });
            if (!orderedHeaders.containsKey("host")) {
                throw new IllegalArgumentException("HTTP Host header is required");
            }
            return orderedHeaders;
        }

        /** Replaces every run of whitespace in {@code value} with a single space. */
        private static String compressWhitespace(String value) {
            return WHITESPACE.matcher(value).replaceAll(" ");
        }

        /** Appends one {@code name:v1,v2,...\n} line per header, in map order. */
        static void addCanonicalHeadersString(StringBuilder request, SortedMap<String, List<String>> canonicalHeaders) {
            canonicalHeaders.forEach((name, values) ->
                    request.append(name).append(':').append(String.join(",", values)).append('\n'));
        }

        /** Joins the header names of {@code headers} with {@code ';'}, in map order. */
        static String signedHeaders(SortedMap<String, ?> headers) {
            return String.join(";", headers.keySet());
        }
    }
}

