/*
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 * Copyright (c) 2023-2025 Jeremy Long. All Rights Reserved.
 */
package io.github.jeremylong.openvulnerability.client.nvd;

import org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy;
import org.apache.hc.client5.http.utils.DateUtils;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.HttpRequest;
import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.protocol.HttpContext;
import org.apache.hc.core5.util.TimeValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;

/**
 * Implements a back-off delay retry strategy that honors the retry-after header.
 */
public class NvdApiRetryStrategy extends DefaultHttpRequestRetryStrategy {

    /**
     * Reference to the logger.
     */
    private static final Logger LOG = LoggerFactory.getLogger(NvdApiRetryStrategy.class);

    /**
     * Warn-level message logged once the execution count reaches half of the configured maximum retries.
     */
    private static final String RETRY_MESSAGE_TEMPLATE = "NVD API request failures are occurring; retrying request for the {} time";

    /**
     * Ordinal suffixes indexed by the last decimal digit of a number (11-13 are handled separately).
     */
    private static final String[] ORDINAL_SUFFIXES = {"th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th"};

    /**
     * Maximum number of allowed retries.
     */
    private final int maxRetries;

    /**
     * Retry interval between subsequent retries in milliseconds.
     */
    private final long delay;

    /**
     * Constructs a new NVD API retry strategy.
     * <p>
     * Responses with status codes 429, 502, 503 and 504 are treated as retriable; no I/O exception class is
     * excluded from retrying beyond the checks performed by {@link DefaultHttpRequestRetryStrategy}.
     *
     * @param maxRetries the maximum number of retry attempts
     * @param delay the delay in milliseconds between retries
     */
    public NvdApiRetryStrategy(int maxRetries, long delay) {
        super(maxRetries, TimeValue.of(delay, TimeUnit.MILLISECONDS), new ArrayList<>(),
                Arrays.asList(429, 502, 503, 504));
        this.maxRetries = maxRetries;
        this.delay = delay;
    }

    /**
     * Decides whether a request that failed with an I/O exception should be retried.
     * <p>
     * The decision is delegated to {@link DefaultHttpRequestRetryStrategy}; the retry is logged only when it will
     * actually be attempted.
     *
     * @param request the request that failed
     * @param exception the exception that caused the failure
     * @param execCount the number of times the request has been executed so far
     * @param context the HTTP context
     * @return {@code true} if the request should be retried
     */
    @Override
    public boolean retryRequest(@Nonnull HttpRequest request, @Nonnull IOException exception, int execCount,
            HttpContext context) {
        final boolean retry = super.retryRequest(request, exception, execCount, context);
        if (retry) {
            logRetryState(request, exception, execCount);
        }
        return retry;
    }

    /**
     * Decides whether a request should be retried based on the response it received.
     * <p>
     * The HTTP client consults this method for every response, including successful ones. Logging therefore
     * happens only after {@link DefaultHttpRequestRetryStrategy} has confirmed that a retry will be attempted, so
     * successful responses do not produce "retrying" warnings.
     *
     * @param response the response received
     * @param execCount the number of times the request has been executed so far
     * @param context the HTTP context
     * @return {@code true} if the request should be retried
     */
    @Override
    public boolean retryRequest(HttpResponse response, int execCount, HttpContext context) {
        final boolean retry = super.retryRequest(response, execCount, context);
        if (retry) {
            if (execCount >= (maxRetries / 2)) {
                LOG.warn(RETRY_MESSAGE_TEMPLATE, toOrdinal(execCount));
            } else if (execCount > 1) {
                LOG.debug("Retrying request {} time", toOrdinal(execCount));
            }
        }
        return retry;
    }

    /**
     * Logs an upcoming retry caused by an I/O exception.
     * <p>
     * A warning is logged once {@code execCount} reaches half of {@code maxRetries}. Earlier retries after the first
     * are logged with the request URI. Exception details are added at debug level.
     *
     * @param request the request being retried
     * @param exception the exception that caused the retry
     * @param execCount the number of times the request has been executed so far
     */
    private void logRetryState(@Nonnull HttpRequest request, @Nonnull IOException exception, int execCount) {
        if (execCount >= (maxRetries / 2)) {
            LOG.warn(RETRY_MESSAGE_TEMPLATE, toOrdinal(execCount));
            if (LOG.isDebugEnabled()) {
                LOG.debug("NVD API request failures with exception : {}. Message: {}", exception.getClass().getName(),
                        exception.getMessage());
            }
        } else if (execCount > 1) {
            LOG.warn("Retrying request {} : {} time", request.getRequestUri(), toOrdinal(execCount));
            if (LOG.isDebugEnabled()) {
                LOG.debug("Retrying request with exception {}. Message: {}", exception.getClass().getName(),
                        exception.getMessage());
            }
        }
    }

    /**
     * Computes how long to wait before the next retry.
     * <p>
     * The calculated back-off is {@code delay * execCount} ms while {@code execCount < maxRetries / 2}, and
     * {@code delay * execCount / 2} ms afterwards. If the response has a {@code Retry-After} header, it may be
     * given as delta-seconds or as an HTTP date. When it resolves to a positive interval longer than the calculated
     * back-off, that interval is used instead. The client therefore never retries sooner than the server asked.
     *
     * @param response the response that triggered the retry
     * @param execCount the number of times the request has been executed so far
     * @param context the HTTP context
     * @return the interval to wait before retrying
     */
    @Override
    public TimeValue getRetryInterval(final HttpResponse response, final int execCount, final HttpContext context) {
        final TimeValue value;
        if (execCount < maxRetries / 2) {
            value = TimeValue.of(delay * execCount, TimeUnit.MILLISECONDS);
        } else {
            value = TimeValue.of(delay * execCount / 2, TimeUnit.MILLISECONDS);
        }
        LOG.debug("Calculated retry interval in {} ms with execCount of {}", value.toMilliseconds(), execCount);
        // check retry after header
        final Header header = response.getFirstHeader(HttpHeaders.RETRY_AFTER);
        if (header != null) {
            TimeValue retryAfter = null;
            final String headerValue = header.getValue();
            try {
                // delta-seconds form, e.g. "Retry-After: 120"
                retryAfter = TimeValue.ofSeconds(Long.parseLong(headerValue));
                LOG.debug("Retry-After header value: {} ms", retryAfter.toMilliseconds());
            } catch (final NumberFormatException ignore) {
                // not a number: fall back to the HTTP-date form; an unparseable value leaves retryAfter null
                final Instant retryAfterDate = DateUtils.parseStandardDate(headerValue);
                if (retryAfterDate != null) {
                    retryAfter = TimeValue.ofMilliseconds(retryAfterDate.toEpochMilli() - System.currentTimeMillis());
                    LOG.debug("Failed to parse value; Retry-After header value: {} ms", retryAfter.toMilliseconds());
                }
            }
            // Honor the server's request: only a longer wait overrides the calculated back-off.
            if (TimeValue.isPositive(retryAfter) && retryAfter.compareTo(value) > 0) {
                LOG.debug("Using Retry-After header value: {} ms", retryAfter.toMilliseconds());
                return retryAfter;
            }
        }
        LOG.debug("Using calculated retry interval: {} ms", value.toMilliseconds());
        return value;
    }

    /**
     * Converts a number to an ordinal string, e.g. {@code 1 -> "1st"}, {@code 12 -> "12th"}, {@code 23 -> "23rd"}.
     * Negative numbers keep their sign, e.g. {@code -2 -> "-2nd"}.
     *
     * @param number the number to convert
     * @return the ordinal string
     */
    public static String toOrdinal(int number) {
        // Use the magnitude so negative inputs never produce a negative array index.
        final int lastTwoDigits = Math.abs(number % 100);
        if (lastTwoDigits >= 11 && lastTwoDigits <= 13) {
            return number + "th";
        }
        return number + ORDINAL_SUFFIXES[lastTwoDigits % 10];
    }
}
