/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.processor.oteltracegroup;

import com.google.common.base.Strings;
import io.micrometer.core.instrument.Counter;
import java.io.IOException;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.common.document.DocumentField;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.logging.DataPrepperMarkers;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.processor.AbstractProcessor;
import org.opensearch.dataprepper.model.processor.Processor;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.trace.DefaultTraceGroupFields;
import org.opensearch.dataprepper.model.trace.Span;
import org.opensearch.dataprepper.model.trace.TraceGroupFields;
import org.opensearch.dataprepper.plugins.processor.oteltracegroup.OTelTraceGroupProcessorConfig;
import org.opensearch.dataprepper.plugins.processor.oteltracegroup.OpenSearchClientFactory;
import org.opensearch.dataprepper.plugins.processor.oteltracegroup.model.TraceGroup;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Processor that back-fills trace group information on spans that arrive without it.
 *
 * <p>For every incoming span whose trace group is null or empty, the processor queries the
 * index alias {@link OTelTraceGroupProcessorConfig#RAW_INDEX_ALIAS} for documents that share
 * the span's trace ID and have an empty {@code parentSpanId} (presumably the root span of the
 * trace). When such a document is found, its trace group name and trace group fields are
 * copied onto the span. Spans are always forwarded, whether or not they could be fixed.
 */
@DataPrepperPlugin(name = "otel_trace_group", pluginType = Processor.class, pluginConfigurationType = OTelTraceGroupProcessorConfig.class)
public class OTelTraceGroupProcessor extends AbstractProcessor<Record<Span>, Record<Span>> {
    /** Metric name: spans received without a trace group. */
    public static final String RECORDS_IN_MISSING_TRACE_GROUP = "recordsInMissingTraceGroup";
    /** Metric name: spans emitted after their trace group was successfully filled in. */
    public static final String RECORDS_OUT_FIXED_TRACE_GROUP = "recordsOutFixedTraceGroup";
    /** Metric name: spans emitted still lacking a trace group. */
    public static final String RECORDS_OUT_MISSING_TRACE_GROUP = "recordsOutMissingTraceGroup";

    private static final Logger LOG = LoggerFactory.getLogger(OTelTraceGroupProcessor.class);

    // Field names shared by the search request and the hit parser; kept in one place so the
    // requested doc-value fields and the fields read back cannot drift apart.
    private static final String TRACE_ID_FIELD = "traceId";
    private static final String PARENT_SPAN_ID_FIELD = "parentSpanId";
    private static final String TRACE_GROUP_FIELD = "traceGroup";
    private static final String TRACE_GROUP_END_TIME_FIELD = "traceGroupFields.endTime";
    private static final String TRACE_GROUP_DURATION_IN_NANOS_FIELD = "traceGroupFields.durationInNanos";
    private static final String TRACE_GROUP_STATUS_CODE_FIELD = "traceGroupFields.statusCode";
    private static final String STRICT_DATE_TIME_FORMAT = "strict_date_time";

    private final OTelTraceGroupProcessorConfig otelTraceGroupProcessorConfig;
    private final RestHighLevelClient restHighLevelClient;
    private final Counter recordsInMissingTraceGroupCounter;
    private final Counter recordsOutFixedTraceGroupCounter;
    private final Counter recordsOutMissingTraceGroupCounter;

    /**
     * Creates the processor, its OpenSearch client and its metric counters.
     *
     * @param pluginMetrics                 metrics registry used to create the three counters
     * @param otelTraceGroupProcessorConfig plugin configuration; supplies the connection configuration
     * @param awsCredentialsSupplier        passed to the client factory when building the REST client
     */
    @DataPrepperPluginConstructor
    public OTelTraceGroupProcessor(PluginMetrics pluginMetrics, OTelTraceGroupProcessorConfig otelTraceGroupProcessorConfig, AwsCredentialsSupplier awsCredentialsSupplier) {
        super(pluginMetrics);
        this.otelTraceGroupProcessorConfig = otelTraceGroupProcessorConfig;
        final OpenSearchClientFactory openSearchClientFactory =
                OpenSearchClientFactory.fromConnectionConfiguration(otelTraceGroupProcessorConfig.getEsConnectionConfig());
        this.restHighLevelClient = openSearchClientFactory.createRestHighLevelClient(awsCredentialsSupplier);
        this.recordsInMissingTraceGroupCounter = pluginMetrics.counter(RECORDS_IN_MISSING_TRACE_GROUP);
        this.recordsOutFixedTraceGroupCounter = pluginMetrics.counter(RECORDS_OUT_FIXED_TRACE_GROUP);
        this.recordsOutMissingTraceGroupCounter = pluginMetrics.counter(RECORDS_OUT_MISSING_TRACE_GROUP);
    }

    /**
     * Forwards every span, filling in trace group information where it is missing and can be
     * looked up.
     *
     * <p>Spans that already carry a trace group are emitted first, in input order. Spans that
     * needed a lookup follow, in the iteration order of a hash set (i.e. unspecified).
     *
     * @param rawSpanRecords the incoming batch of span records
     * @return all input records; those that were missing a trace group have been updated in
     *         place when a matching trace group was found
     */
    @Override
    public Collection<Record<Span>> doExecute(Collection<Record<Span>> rawSpanRecords) {
        final Collection<Record<Span>> recordsOut = new LinkedList<>();
        final Collection<Record<Span>> recordsMissingTraceGroupInfo = new HashSet<>();
        final Collection<String> traceIdsToLookUp = new HashSet<>();

        for (final Record<Span> record : rawSpanRecords) {
            final Span span = record.getData();
            if (Strings.isNullOrEmpty(span.getTraceGroup())) {
                traceIdsToLookUp.add(span.getTraceId());
                recordsMissingTraceGroupInfo.add(record);
                this.recordsInMissingTraceGroupCounter.increment();
            } else {
                recordsOut.add(record);
            }
        }

        // Nothing to fix: avoid a pointless round trip to OpenSearch for this batch.
        if (recordsMissingTraceGroupInfo.isEmpty()) {
            return recordsOut;
        }

        final Map<String, TraceGroup> traceIdToTraceGroup = this.searchTraceGroupByTraceIds(traceIdsToLookUp);
        for (final Record<Span> record : recordsMissingTraceGroupInfo) {
            final Span span = record.getData();
            final String traceId = span.getTraceId();
            final TraceGroup traceGroup = traceIdToTraceGroup.get(traceId);

            boolean fixed = false;
            if (traceGroup == null) {
                LOG.warn("Failed to find traceGroup for spanId: {} due to traceGroup missing for traceId: {}", span.getSpanId(), traceId);
            } else {
                try {
                    this.fillInTraceGroupInfo(span, traceGroup);
                    fixed = true;
                } catch (Exception e) {
                    LOG.error(DataPrepperMarkers.EVENT, "Failed to process the span: [{}]", span, e);
                }
            }

            // The record is forwarded exactly once regardless of the outcome above.
            recordsOut.add(record);
            if (fixed) {
                this.recordsOutFixedTraceGroupCounter.increment();
            } else {
                this.recordsOutMissingTraceGroupCounter.increment();
            }
        }
        return recordsOut;
    }

    /**
     * Copies the trace group name and trace group fields onto the given span.
     *
     * @param span       span to update in place
     * @param traceGroup source of the trace group name and fields
     */
    private void fillInTraceGroupInfo(Span span, TraceGroup traceGroup) {
        span.setTraceGroup(traceGroup.getTraceGroup());
        span.setTraceGroupFields(traceGroup.getTraceGroupFields());
    }

    /**
     * Looks up trace group information for the given trace IDs.
     *
     * <p>This is best-effort: a failed search is logged and yields whatever entries were
     * collected so far (possibly none), and a single hit that cannot be converted is logged
     * and skipped without discarding the remaining hits.
     *
     * @param traceIds trace IDs to look up; when empty no request is sent
     * @return mutable map from trace ID to trace group; never {@code null}
     */
    private Map<String, TraceGroup> searchTraceGroupByTraceIds(Collection<String> traceIds) {
        final Map<String, TraceGroup> traceIdToTraceGroup = new HashMap<>();
        if (traceIds.isEmpty()) {
            return traceIdToTraceGroup;
        }

        final SearchRequest searchRequest = this.createSearchRequest(traceIds);
        try {
            final SearchResponse searchResponse = this.restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT);
            for (final SearchHit searchHit : searchResponse.getHits().getHits()) {
                try {
                    this.fromSearchHitToMapEntry(searchHit)
                            .ifPresent(entry -> traceIdToTraceGroup.put(entry.getKey(), entry.getValue()));
                } catch (RuntimeException e) {
                    // One malformed document (bad date, unexpected value type) must not cost
                    // us the trace groups carried by the other hits.
                    LOG.warn("Skipping search hit [{}] with unreadable traceGroup fields", searchHit.getId(), e);
                }
            }
        } catch (Exception e) {
            // Trailing throwable keeps the stack trace; the message text itself is unchanged.
            LOG.error("Search request for traceGroup failed for traceIds: {} due to {}", traceIds, e.getMessage(), e);
        }
        return traceIdToTraceGroup;
    }

    /**
     * Builds the search request: documents whose trace ID is one of {@code traceIds} and whose
     * {@code parentSpanId} is the empty string, returning only the doc-value fields needed to
     * build a {@link TraceGroup} (the document source is not fetched).
     *
     * @param traceIds trace IDs to match
     * @return the search request against {@link OTelTraceGroupProcessorConfig#RAW_INDEX_ALIAS}
     */
    private SearchRequest createSearchRequest(Collection<String> traceIds) {
        final QueryBuilder query = QueryBuilders.boolQuery()
                .must(QueryBuilders.termsQuery(TRACE_ID_FIELD, traceIds))
                .must(QueryBuilders.termQuery(PARENT_SPAN_ID_FIELD, ""));

        final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(query);
        searchSourceBuilder.docValueField(TRACE_ID_FIELD);
        searchSourceBuilder.docValueField(TRACE_GROUP_FIELD);
        searchSourceBuilder.docValueField(TRACE_GROUP_END_TIME_FIELD, STRICT_DATE_TIME_FORMAT);
        searchSourceBuilder.docValueField(TRACE_GROUP_DURATION_IN_NANOS_FIELD);
        searchSourceBuilder.docValueField(TRACE_GROUP_STATUS_CODE_FIELD);
        searchSourceBuilder.fetchSource(false);

        final SearchRequest searchRequest = new SearchRequest(new String[]{OTelTraceGroupProcessorConfig.RAW_INDEX_ALIAS});
        searchRequest.source(searchSourceBuilder);
        return searchRequest;
    }

    /**
     * Converts one search hit into a (trace ID, trace group) entry.
     *
     * @param searchHit hit returned by the request built in {@code createSearchRequest}
     * @return the entry, or {@link Optional#empty()} when any of the five expected fields is
     *         absent from the hit or has no value
     * @throws RuntimeException if a field value has an unexpected type or the end time cannot
     *                          be parsed as an ISO-8601 instant
     */
    private Optional<Map.Entry<String, TraceGroup>> fromSearchHitToMapEntry(SearchHit searchHit) {
        final Object traceId = firstValueOrNull(searchHit.field(TRACE_ID_FIELD));
        final Object traceGroupName = firstValueOrNull(searchHit.field(TRACE_GROUP_FIELD));
        final Object endTime = firstValueOrNull(searchHit.field(TRACE_GROUP_END_TIME_FIELD));
        final Object durationInNanos = firstValueOrNull(searchHit.field(TRACE_GROUP_DURATION_IN_NANOS_FIELD));
        final Object statusCode = firstValueOrNull(searchHit.field(TRACE_GROUP_STATUS_CODE_FIELD));

        // A missing field and a present-but-valueless field are both treated as "no trace group".
        if (!Stream.of(traceId, traceGroupName, endTime, durationInNanos, statusCode).allMatch(Objects::nonNull)) {
            return Optional.empty();
        }

        final TraceGroupFields traceGroupFields = DefaultTraceGroupFields.builder()
                .withEndTime(this.normalizeDateTime((String) endTime))
                .withDurationInNanos(((Number) durationInNanos).longValue())
                .withStatusCode(((Number) statusCode).intValue())
                .build();
        final TraceGroup traceGroup = new TraceGroup.TraceGroupBuilder()
                .setTraceGroup((String) traceGroupName)
                .setTraceGroupFields(traceGroupFields)
                .build();
        return Optional.of(new AbstractMap.SimpleEntry<String, TraceGroup>((String) traceId, traceGroup));
    }

    /**
     * Returns the value reported by {@code getValue()} for the field, or {@code null} when the
     * field itself is absent from the hit.
     */
    private static Object firstValueOrNull(DocumentField field) {
        if (field == null) {
            return null;
        }
        return field.<Object>getValue();
    }

    /**
     * Round-trips the timestamp through {@link Instant} so it is emitted in
     * {@link Instant#toString()} form.
     *
     * @param dateTimeString timestamp accepted by {@link Instant#parse(CharSequence)}
     * @return the same instant rendered by {@link Instant#toString()}
     */
    private String normalizeDateTime(String dateTimeString) {
        return Instant.parse(dateTimeString).toString();
    }

    /** No preparation is required before shutdown. */
    @Override
    public void prepareForShutdown() {
    }

    /**
     * @return always {@code true}; this processor reports itself ready for shutdown at any time
     */
    @Override
    public boolean isReadyForShutdown() {
        return true;
    }

    /**
     * Closes the OpenSearch client.
     *
     * @throws RuntimeException wrapping the {@link IOException} if closing the client fails
     */
    @Override
    public void shutdown() {
        try {
            this.restHighLevelClient.close();
        } catch (IOException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }
}

