/*
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 * Copyright (c) 2022-2024 Jeremy Long. All Rights Reserved.
 */
package io.github.jeremylong.openvulnerability.client.nvd;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import io.github.jeremylong.openvulnerability.client.HttpAsyncClientSupplier;
import io.github.jeremylong.openvulnerability.client.PagedDataSource;
import org.apache.hc.client5.http.async.methods.SimpleHttpRequest;
import org.apache.hc.client5.http.async.methods.SimpleHttpResponse;
import org.apache.hc.client5.http.async.methods.SimpleRequestBuilder;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.NameValuePair;
import org.apache.hc.core5.net.URIBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

/**
 * A simple client for the NVD CVE API. Use the NvdCveClientBuilder with the desired filters to build the client and
 * then iterate over the results:
 *
 * <pre>
 * try (NvdCveClient api = NvdCveClientBuilder.aNvdCveApi().build()) {
 *     while (api.hasNext()) {
 *         Collection&lt;DefCveItem&gt; items = api.next();
 *     }
 * }
 * </pre>
 *
 * <p>
 * The first call to {@link #next()} requests the page starting at index 0. Once the total number of results is known,
 * the remaining pages are queued round-robin across the configured rate limited clients. The paging state is kept in
 * unsynchronized collections, so an instance should be confined to a single thread.
 * </p>
 *
 * @author Jeremy Long
 * @see <a href="https://nvd.nist.gov/developers/vulnerabilities">NVD CVE API</a>
 */
public class NvdCveClient implements PagedDataSource<DefCveItem> {

    /**
     * Reference to the logger.
     */
    private static final Logger LOG = LoggerFactory.getLogger(NvdCveClient.class);
    /**
     * The default endpoint for the NVD CVE API.
     */
    private static final String DEFAULT_ENDPOINT = "https://services.nvd.nist.gov/rest/json/cves/2.0";
    /**
     * The header name for the NVD API Key.
     */
    private static final String API_KEY_NAME = "apiKey";
    /**
     * The maximum number of retries for a single {@link #next()} call, and for re-requesting a single failed start
     * index.
     */
    private static final int MAX_RETRIES = 5;
    /**
     * The NVD API key; can be null if a key is not used.
     */
    private final String apiKey;
    /**
     * The NVD API endpoint used to call the NVD CVE API.
     */
    private final String endpoint;
    /**
     * Jackson object mapper.
     */
    private final ObjectMapper objectMapper;
    /**
     * The user agent to append to the default open-vulnerability-client's user-agent string.
     */
    private final String userAgent;
    /**
     * The rate limited HTTP clients for calling the NVD APIs; set to {@code null} by {@link #close()}.
     */
    private List<RateLimitedClient> clients;
    /**
     * The list of future responses.
     */
    private final List<Future<RateLimitedCall>> futures = new ArrayList<>();
    /**
     * The map of indexes to retrieve from the NVD and their retry count. This is used to retry when failures have
     * occurred on a single index.
     */
    private final Map<Integer, Integer> indexesToRetrieve = new HashMap<>();
    /**
     * Flag indicating if the first call has been made.
     */
    private boolean firstCall = true;
    /**
     * The number of results per page.
     */
    private int resultsPerPage = 2000;
    /**
     * The total results from the NVD CVE API call.
     */
    private int totalAvailable = -1;
    /**
     * The maximum number of pages to retrieve from the NVD API.
     */
    private final int maxPageCount;
    /**
     * A list of filters to apply to the request.
     */
    private List<NameValuePair> filters;
    /**
     * The last HTTP Status Code returned by the API that terminated the iteration; 200 while paging is healthy.
     */
    private int lastStatusCode = 200;
    /**
     * The last lastModified timestamp from the NVD data processed.
     */
    private ZonedDateTime lastUpdated = null;
    /**
     * The version of the client.
     */
    private String version = "unknown";

    /**
     * Constructs a new NVD CVE API client.
     *
     * @param apiKey the api key; can be null
     * @param endpoint the endpoint for the NVD CVE API; if null the default endpoint is used
     * @param threadCount the number of threads to use when calling the NVD API.
     * @param maxPageCount the maximum number of pages to retrieve from the NVD API.
     */
    NvdCveClient(String apiKey, String endpoint, int threadCount, int maxPageCount) {
        this(apiKey, endpoint, 0, threadCount, maxPageCount, 10, null);
    }

    /**
     * Constructs a new NVD CVE API client.
     *
     * @param apiKey the api key; can be null
     * @param endpoint the endpoint for the NVD CVE API; if null the default endpoint is used
     * @param threadCount the number of threads to use when calling the NVD API.
     * @param maxPageCount the maximum number of pages to retrieve from the NVD API.
     * @param maxRetryCount the maximum number of retries for 503 and 429 status code responses.
     */
    NvdCveClient(String apiKey, String endpoint, int threadCount, int maxPageCount, int maxRetryCount) {
        this(apiKey, endpoint, 0, threadCount, maxPageCount, maxRetryCount, null);
    }

    /**
     * Constructs a new NVD CVE API client.
     *
     * @param apiKey the api key; can be null
     * @param endpoint the endpoint for the NVD CVE API; if null the default endpoint is used
     * @param delay the delay in milliseconds between API calls on a single thread.
     * @param threadCount the number of threads to use when calling the NVD API.
     * @param maxPageCount the maximum number of pages to retrieve from the NVD API.
     * @param maxRetryCount the maximum number of retries for 503 and 429 status code responses.
     * @param httpClientSupplier supplier for custom HTTP clients; if {@code null} a default client will be used
     */
    NvdCveClient(String apiKey, String endpoint, long delay, int threadCount, int maxPageCount, int maxRetryCount,
            HttpAsyncClientSupplier httpClientSupplier) {
        this(apiKey, endpoint, delay, threadCount, maxPageCount, maxRetryCount, httpClientSupplier, null);
    }

    /**
     * Constructs a new NVD CVE API client.
     *
     * @param apiKey the api key; can be null. Without a key only a single thread is used.
     * @param endpoint the endpoint for the NVD CVE API; if null the default endpoint is used
     * @param delay the delay in milliseconds between API calls on a single thread; 0 selects 6500ms without an API
     *            key and 600ms with one.
     * @param threadCount the number of threads to use when calling the NVD API; values &lt;= 0 are treated as 1.
     * @param maxPageCount the maximum number of pages to retrieve from the NVD API; &lt;= 0 means no limit.
     * @param maxRetryCount the maximum number of retries for 503 and 429 status code responses.
     * @param httpClientSupplier supplier for custom HTTP clients; if {@code null} a default client will be used
     * @param userAgent the user agent to append to the default open-vulnerability-client's user-agent string
     */
    NvdCveClient(String apiKey, String endpoint, long delay, int threadCount, int maxPageCount, int maxRetryCount,
            HttpAsyncClientSupplier httpClientSupplier, String userAgent) {

        this.apiKey = apiKey;
        this.userAgent = userAgent;
        this.endpoint = endpoint == null ? DEFAULT_ENDPOINT : endpoint;
        if (threadCount <= 0) {
            threadCount = 1;
        }
        this.maxPageCount = maxPageCount;
        // configure the rate limit slightly higher than the published limits:
        // https://nvd.nist.gov/developers/start-here (see Rate Limits)

        RateMeter meter;
        if (apiKey == null) {
            if (threadCount > 1) {
                LOG.warn(
                        "No api key provided; as such the thread count has been reset to 1 instead of the requested {}",
                        threadCount);
                threadCount = 1;
            }
            meter = new RateMeter(5, 32500);
        } else {
            meter = new RateMeter(50, 32500);
        }
        clients = new ArrayList<>(threadCount);
        if (delay == 0) {
            delay = apiKey == null ? 6500 : 600;
        }
        // all clients share one meter so the combined request rate stays within the NVD limit
        for (int i = 0; i < threadCount; i++) {
            clients.add(new RateLimitedClient(maxRetryCount, delay, meter, httpClientSupplier));
        }
        objectMapper = new ObjectMapper();
        objectMapper.registerModule(new JavaTimeModule());

        loadVersion();
    }

    /**
     * Reads the client version from the {@code version.properties} classpath resource. The current value of
     * {@link #version} is kept when the resource or the {@code version} property is missing or cannot be read.
     */
    private void loadVersion() {
        // try-with-resources tolerates a null resource (close is skipped) and always closes the stream otherwise
        try (InputStream in = getClass().getClassLoader().getResourceAsStream("version.properties")) {
            if (in == null) {
                LOG.debug("version.properties not found; using version {}", version);
                return;
            }
            Properties props = new Properties();
            props.load(in);
            version = props.getProperty("version", version);
        } catch (IOException e) {
            LOG.debug("Error loading version.properties", e);
        }
    }

    /**
     * Set the filter parameters for the NVD CVE API calls.
     *
     * @param filters the list of parameters used to filter the results in the API call
     */
    void setFilters(List<NameValuePair> filters) {
        this.filters = filters;
    }

    /**
     * The number of results per page; the default is 2000.
     *
     * @param resultsPerPage the number of results per page
     */
    void setResultsPerPage(int resultsPerPage) {
        this.resultsPerPage = resultsPerPage;
    }

    /**
     * Returns the last HTTP Status Code.
     *
     * @return the last HTTP Status Code
     */
    public int getLastStatusCode() {
        return lastStatusCode;
    }

    /**
     * Only available after the first call to `next()`; returns the total number of records that will be returned.
     *
     * @return the total number of records that will be returned
     */
    @Override
    public int getTotalAvailable() {
        return totalAvailable;
    }

    /**
     * Asynchronously calls the NVD CVE API for a single page.
     *
     * @param clientIndex the index into {@link #clients} of the rate limited client used to send the request
     * @param startIndex the start index to request
     * @return the future
     * @throws NvdApiException thrown if the request URI cannot be built from the endpoint and filters
     */
    private Future<RateLimitedCall> callApi(int clientIndex, int startIndex) throws NvdApiException {
        try {
            URIBuilder uriBuilder = new URIBuilder(endpoint);
            if (filters != null) {
                uriBuilder.addParameters(filters);
            }
            uriBuilder.addParameter("resultsPerPage", Integer.toString(resultsPerPage));
            uriBuilder.addParameter("startIndex", Integer.toString(startIndex));
            final SimpleRequestBuilder builder = SimpleRequestBuilder.get();
            if (apiKey != null) {
                builder.addHeader(API_KEY_NAME, apiKey);
            }
            String ua = "open-vulnerability-client/" + version;
            if (userAgent != null) {
                ua += "; " + userAgent;
            }
            builder.addHeader("User-Agent", ua);
            URI uri = uriBuilder.build();
            LOG.debug("requesting URI: {}", uri);
            final SimpleHttpRequest request = builder.setUri(uri).build();
            return clients.get(clientIndex).execute(request, clientIndex, startIndex);
        } catch (URISyntaxException e) {
            throw new NvdApiException(e);
        }
    }

    /**
     * Cancels any outstanding requests and closes the underlying HTTP clients. Errors while closing a client are logged
     * at debug level and otherwise ignored.
     */
    @Override
    public void close() {
        indexesToRetrieve.clear();
        if (!futures.isEmpty()) {
            for (Future<RateLimitedCall> future : futures) {
                if (!future.isDone()) {
                    future.cancel(true);
                }
            }
            futures.clear();
        }
        if (clients != null) {
            for (RateLimitedClient client : clients) {
                try {
                    client.close();
                } catch (Exception ex) {
                    LOG.debug("Error closing client during `close`", ex);
                }
            }
            clients = null;
        }
    }

    /**
     * Returns whether more pages may be retrieved. Returns false once a terminating non-200 status was recorded. When no
     * requests are in flight but failed start indexes remain, those indexes are re-queued first.
     *
     * @return true if another call to {@link #next()} may produce data
     */
    @Override
    public boolean hasNext() {
        if (lastStatusCode != 200) {
            return false;
        }
        if (firstCall) {
            return true;
        }
        if (futures.isEmpty() && !indexesToRetrieve.isEmpty()) {
            queueUnsuccessful();
        }
        return !futures.isEmpty();
    }

    /**
     * <p>
     * Retrieves the next entry. Note that even if `hasNext()` returns true it is possible that `next()` will return
     * null. This will generally only occur on the very first call.
     * </p>
     *
     * @return the next collection of CVE entries
     */
    @Override
    public Collection<DefCveItem> next() {
        return _next(0);
    }

    /**
     * Waits for the next completed request and converts it into a page of results, retrying transient failures.
     *
     * @param retryCount the number of retries already attempted for this call to {@link #next()}
     * @return the next page of CVE entries; may be null when nothing more could be retrieved
     * @throws NvdApiRetryExceededException if more than {@value #MAX_RETRIES} retries were needed
     * @throws NvdApiException on a non-recoverable API error or if the thread is interrupted
     */
    private Collection<DefCveItem> _next(int retryCount) {
        if (retryCount > MAX_RETRIES) {
            throw new NvdApiRetryExceededException(
                    "NVD Update Failed: attempted to retrieve data from the NVD unsuccessfully five times.");
        }
        if (firstCall) {
            futures.add(callApi(0, 0));
        }
        try {
            final RateLimitedCall call = getCompletedFuture();
            if (call == null) {
                return hasNext() ? _next(retryCount + 1) : null;
            }
            final SimpleHttpResponse response = call.getResponse();
            if (response.getCode() != 200) {
                return handleErrorResponse(response, retryCount);
            }
            return handleSuccessResponse(call, response, retryCount);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            close();
            throw new NvdApiException(e);
        } catch (ExecutionException e) {
            // in rare cases we get an error from the NVD - log the error and only fail if we retry too many times
            LOG.debug("Error retrieving the NVD data", e);
            if (hasNext()) {
                return _next(retryCount + 1);
            }
            close();
        }
        return null;
    }

    /**
     * Parses a successful (HTTP 200) response into a page of results. On the first call, the remaining pages are
     * queued.
     *
     * @param call the completed call
     * @param response the HTTP 200 response of {@code call}
     * @param retryCount the number of retries already attempted for this call to {@link #next()}
     * @return the vulnerabilities contained in the page
     * @throws NvdApiException if the body cannot be mapped to the NVD JSON model
     */
    private Collection<DefCveItem> handleSuccessResponse(RateLimitedCall call, SimpleHttpResponse response,
            int retryCount) {
        LOG.debug("Content-Type Received: {}", response.getContentType());
        final String json = bodyAsString(response);

        final CveApiJson20 current;
        try {
            current = objectMapper.readValue(json, CveApiJson20.class);
            this.indexesToRetrieve.remove(call.getStartIndex());
        } catch (JsonMappingException e) {
            LOG.debug("Error parsing NVD data", e);
            // Fail fast on JSON parsing errors
            throw new NvdApiException("Failed to parse NVD data", e);
        } catch (JsonProcessingException e) {
            LOG.debug("Error processing NVD data", e);
            // Re-try on what might be temporarily streaming errors
            return _next(retryCount + 1);
        }
        this.totalAvailable = current.getTotalResults();
        lastUpdated = findLastUpdated(lastUpdated, current.getVulnerabilities());
        if (firstCall) {
            firstCall = false;
            queueCalls();
        }
        if (futures.isEmpty() && !indexesToRetrieve.isEmpty()) {
            queueUnsuccessful();
        }
        return current.getVulnerabilities();
    }

    /**
     * Handles a non-200 response. A "resultsPerPage parameter cannot exceed" message is recoverable: the page size is
     * lowered and the request retried without recording the status code. Every other error records the status code
     * (which makes {@link #hasNext()} return false) and throws.
     *
     * @param response the non-200 response
     * @param retryCount the number of retries already attempted for this call to {@link #next()}
     * @return the result of the retry when the error is recoverable
     * @throws NvdApiException for every non-recoverable error
     */
    private Collection<DefCveItem> handleErrorResponse(SimpleHttpResponse response, int retryCount) {
        final int statusCode = response.getCode();
        LOG.debug("Status Code: {}", statusCode);
        LOG.debug("Reason: {}", response.getReasonPhrase());
        LOG.debug("Response Headers:");
        String msg = null;
        for (Header header : response.getHeaders()) {
            LOG.debug("Key : {} ,Value : {}", header.getName(), header.getValue());
            if ("message".equals(header.getName())) {
                msg = header.getValue();
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Response: {}", bodyAsString(response));
        }
        if (msg != null) {
            msg = msg.trim();
            if (msg.contains("Invalid apiKey")) {
                lastStatusCode = statusCode;
                throw new NvdApiException(maskedInvalidApiKeyMessage());
            } else if (msg.startsWith("resultsPerPage parameter cannot exceed")) {
                // Recoverable: lastStatusCode is deliberately NOT updated, otherwise hasNext() would report false
                // after the retried request succeeds and only the first page would ever be returned.
                this.resultsPerPage = parseResultsPerPage(msg);
                LOG.warn("{}", msg);
                LOG.warn("NVD requested a lower resultsPerPage; settings to {}", this.resultsPerPage);
                return _next(retryCount + 1);
            }
            lastStatusCode = statusCode;
            throw new NvdApiException("NVD Returned Status Code: " + statusCode + " - " + msg);
        }
        lastStatusCode = statusCode;
        throw new NvdApiException("NVD Returned Status Code: " + statusCode);
    }

    /**
     * Builds the "Invalid API Key" error message while revealing at most the first and last five characters of the
     * configured key. Keys of five characters or fewer are fully masked.
     *
     * @return the masked error message
     */
    private String maskedInvalidApiKeyMessage() {
        if (apiKey == null) {
            return "Invalid API Key";
        }
        if (apiKey.length() > 30) {
            return String.format("Invalid API Key: %s-*****-%s", apiKey.substring(0, 5),
                    apiKey.substring(apiKey.length() - 5));
        }
        if (apiKey.length() > 5) {
            return String.format("Invalid API Key: %s-*****", apiKey.substring(0, 5));
        }
        return "Invalid API Key: *****";
    }

    /**
     * Decodes the response body as UTF-8.
     *
     * @param response the HTTP response
     * @return the body text, or an empty string when the response has no body
     */
    private static String bodyAsString(SimpleHttpResponse response) {
        final byte[] body = response.getBodyBytes();
        return body == null ? "" : new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Attempts to parse the resultsPerPage error message to determine the maximum number of results per page. The
     * value is expected to be the last whitespace-separated token of the message, optionally followed by a period.
     *
     * @param msg the error message from the NVD
     * @return the parsed results per page if successful and positive; otherwise the previously configured results per
     *         page
     */
    protected int parseResultsPerPage(String msg) {
        if (msg == null) {
            return resultsPerPage;
        }
        String value = msg.trim();
        if (value.endsWith(".")) {
            value = value.substring(0, value.length() - 1);
        }
        // search the trimmed value itself, not the original message, so the indexes refer to the same string
        value = value.substring(value.lastIndexOf(' ') + 1);
        try {
            final int parsed = Integer.parseInt(value);
            if (parsed > 0) {
                return parsed;
            }
            // a non-positive page size would make queueCalls() loop forever
            LOG.debug("Ignoring non-positive resultsPerPage parsed from {}", msg);
        } catch (NumberFormatException e) {
            LOG.debug("Error parsing {}", msg, e);
        }
        return resultsPerPage;
    }

    /**
     * Retrieve the latest last updated date from the list of vulnerabilities. Entries without a last modified date are
     * ignored.
     *
     * @param lastUpdated the last updated date; may be null.
     * @param list the list of vulnerabilities; may be null.
     * @return the latest last modified date, or {@code lastUpdated} if no later date was found.
     */
    private ZonedDateTime findLastUpdated(ZonedDateTime lastUpdated, List<DefCveItem> list) {
        ZonedDateTime current = lastUpdated;
        if (list == null) {
            return current;
        }
        for (DefCveItem item : list) {
            if (item == null || item.getCve() == null) {
                continue;
            }
            final ZonedDateTime modified = item.getCve().getLastModified();
            if (modified != null && (current == null || current.compareTo(modified) < 0)) {
                current = modified;
            }
        }
        return current;
    }

    @Override
    public ZonedDateTime getLastUpdated() {
        return lastUpdated;
    }

    /**
     * Blocks until one of the queued futures has completed, polling every 500ms. The completed future is removed from
     * the queue.
     *
     * @return the result of the completed future, or null if no futures are queued
     * @throws InterruptedException if interrupted while waiting
     * @throws ExecutionException if the completed request failed
     */
    private RateLimitedCall getCompletedFuture() throws InterruptedException, ExecutionException {
        Future<RateLimitedCall> result = null;
        while (result == null && !futures.isEmpty()) {
            for (Future<RateLimitedCall> future : futures) {
                if (future.isDone()) {
                    result = future;
                    break;
                }
            }
            if (result == null) {
                Thread.sleep(500);
            }
        }
        if (result != null) {
            futures.remove(result);
            return result.get();
        }
        return null;
    }

    /**
     * Re-queues every start index that has not yet been retrieved successfully, distributing the requests round-robin
     * across the clients and incrementing each index's retry count.
     *
     * @throws NvdApiRetryExceededException if an index has already been retried more than {@value #MAX_RETRIES} times
     */
    private void queueUnsuccessful() {
        int clientIndex = 0;
        for (Map.Entry<Integer, Integer> entry : indexesToRetrieve.entrySet()) {
            if (entry.getValue() > MAX_RETRIES) {
                throw new NvdApiRetryExceededException("NVD Update Failed: attempted to retrieve starting index "
                        + entry.getKey() + " from the NVD unsuccessfully five times.");
            }
            entry.setValue(entry.getValue() + 1);
            futures.add(callApi(clientIndex, entry.getKey()));
            clientIndex = (clientIndex + 1) % clients.size();
        }
    }

    /**
     * Queues requests for every page after the first one (which was already requested), up to {@link #maxPageCount}
     * pages in total when that limit is positive, distributing them round-robin across the clients.
     */
    private void queueCalls() {
        int clientIndex = 0;
        int pageCount = 1;
        // start at results per page - as 0 was already requested
        for (int i = resultsPerPage; (maxPageCount <= 0 || pageCount < maxPageCount)
                && i < totalAvailable; i += resultsPerPage) {
            indexesToRetrieve.put(i, 0);
            futures.add(callApi(clientIndex, i));
            pageCount += 1;
            clientIndex = (clientIndex + 1) % clients.size();
        }
    }
}
