/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.parallelconsumer.state;

import io.confluent.parallelconsumer.internal.EpochAndRecordsMap;
import io.confluent.parallelconsumer.internal.PCModule;
import io.confluent.parallelconsumer.metrics.PCMetrics;
import io.confluent.parallelconsumer.metrics.PCMetricsDef;
import io.confluent.parallelconsumer.offsets.OffsetMapCodecManager;
import io.confluent.parallelconsumer.state.PartitionState;
import io.confluent.parallelconsumer.state.RemovedPartitionState;
import io.confluent.parallelconsumer.state.ShardManager;
import io.confluent.parallelconsumer.state.WorkContainer;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Meter;
import io.micrometer.core.instrument.Tag;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Owns the per-partition {@link PartitionState} objects, the per-partition assignment epochs and the per-partition
 * slow-work counters, and keeps them in step with Kafka rebalances via {@link ConsumerRebalanceListener}.
 * <p>
 * Revoked or lost partitions are not deleted from the state map: their entry is replaced with the
 * {@link RemovedPartitionState} singleton, and all "assigned partition" queries filter those entries out.
 * <p>
 * Every per-partition map is a {@link ConcurrentHashMap}, so the rebalance callbacks and the other accessors may be
 * invoked from different threads.
 */
public class PartitionStateManager<K, V>
implements ConsumerRebalanceListener {
    private static final Logger log = LoggerFactory.getLogger(PartitionStateManager.class);
    public static final double USED_PAYLOAD_THRESHOLD_MULTIPLIER_DEFAULT = 0.75;
    private static double USED_PAYLOAD_THRESHOLD_MULTIPLIER = 0.75;
    private final ShardManager<K, V> sm;
    private final Map<TopicPartition, PartitionState<K, V>> partitionStates = new ConcurrentHashMap<>();
    private final Map<TopicPartition, Long> partitionsAssignmentEpochs = new ConcurrentHashMap<>();
    private final PCModule<K, V> module;
    private Gauge numberOfPartitionsGauge;
    private Gauge totalIncompletesGauge;
    /**
     * Slow-record counters keyed by partition. Written by the rebalance callbacks and read by
     * {@link #incrementSlowWorkCounter}, so it is concurrent like the other per-partition maps.
     */
    private final Map<TopicPartition, Counter> slowWorkCounters = new ConcurrentHashMap<>();
    private final PCMetrics pcMetrics;

    /**
     * @param module source of the metrics registry and of the dependencies needed to load partition state
     * @param sm     shard manager notified when partitions are added or removed
     */
    public PartitionStateManager(PCModule<K, V> module, ShardManager<K, V> sm) {
        this.sm = sm;
        this.module = module;
        this.pcMetrics = module.pcMetrics();
        this.initMetrics();
    }

    /**
     * @return the state for {@code tp}: the {@link RemovedPartitionState} marker if it was revoked or lost, or
     *         {@code null} if it has never been assigned
     */
    public PartitionState<K, V> getPartitionState(TopicPartition tp) {
        return this.partitionStates.get(tp);
    }

    private PartitionState<K, V> getPartitionState(EpochAndRecordsMap.RecordsAndEpoch recordsAndEpoch) {
        return this.getPartitionState(recordsAndEpoch.getTopicPartition());
    }

    protected PartitionState<K, V> getPartitionState(WorkContainer<K, V> workContainer) {
        return this.getPartitionState(workContainer.getTopicPartition());
    }

    /**
     * Handles a new assignment: sanity-checks re-assignments, bumps the assignment epoch, loads the partitions'
     * state, registers their slow-work counters and then lets the shard manager drop stale containers.
     * <p>
     * Any exception from loading the state is logged and rethrown.
     */
    @Override
    public void onPartitionsAssigned(Collection<TopicPartition> assignedPartitions) {
        log.debug("Partitions assigned: {}", assignedPartitions);
        for (TopicPartition partitionAssignment : assignedPartitions) {
            PartitionState<K, V> previouslyAssignedState = this.partitionStates.get(partitionAssignment);
            if (previouslyAssignedState == null) {
                // brand new partition, nothing to check
                continue;
            }
            if (previouslyAssignedState.isRemoved()) {
                log.trace("Reassignment of previously revoked partition {} - state: {}", partitionAssignment, previouslyAssignedState);
            } else {
                log.warn("New assignment of partition which already exists and isn't recorded as removed in partition state. Could be a state bug - was the partition revocation somehow missed, or is this a race? Please file a GH issue. Partition: {}, state: {}", partitionAssignment, previouslyAssignedState);
            }
        }
        this.incrementPartitionAssignmentEpoch(assignedPartitions);
        try {
            OffsetMapCodecManager<K, V> om = new OffsetMapCodecManager<>(this.module);
            Map<TopicPartition, PartitionState<K, V>> loadedStates = om.loadPartitionStateForAssignment(assignedPartitions);
            this.partitionStates.putAll(loadedStates);
            this.initPartitionCounters(assignedPartitions);
            this.sm.removeStaleContainers();
        }
        catch (Exception e) {
            log.error("Error in onPartitionsAssigned", e);
            throw e;
        }
    }

    /** Registers a slow-records counter for each partition that does not already have one. */
    private void initPartitionCounters(Collection<TopicPartition> assignedPartitions) {
        for (TopicPartition topicPartition : assignedPartitions) {
            this.slowWorkCounters.computeIfAbsent(topicPartition, tp -> this.pcMetrics.getCounterFromMetricDef(PCMetricsDef.SLOW_RECORDS,
                    Tag.of("topic", tp.topic()),
                    Tag.of("partition", String.valueOf(tp.partition()))));
        }
    }

    /** Removes the slow-records counters of the given partitions from the local map and from the metrics registry. */
    private void deregisterPartitionCounters(Collection<TopicPartition> removedPartitions) {
        for (TopicPartition topicPartition : removedPartitions) {
            Counter counter = this.slowWorkCounters.remove(topicPartition);
            if (counter != null) {
                this.pcMetrics.removeMeter((Meter)counter);
            }
        }
    }

    /** Increments the slow-records counter of {@code topicPartition}; a no-op if it has no counter registered. */
    public void incrementSlowWorkCounter(TopicPartition topicPartition) {
        Optional.ofNullable(this.slowWorkCounters.get(topicPartition)).ifPresent(Counter::increment);
    }

    /** Treats revoked partitions as removed; exceptions are logged and rethrown. */
    @Override
    public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
        log.info("Partitions revoked: {}", partitions);
        try {
            this.onPartitionsRemoved(partitions);
        }
        catch (Exception e) {
            log.error("Error in onPartitionsRevoked", e);
            throw e;
        }
    }

    /**
     * Common removal path for revoked and lost partitions: bumps the assignment epoch, swaps in the removed-state
     * marker, deregisters counters and lets the shard manager drop stale containers.
     */
    void onPartitionsRemoved(Collection<TopicPartition> partitions) {
        this.incrementPartitionAssignmentEpoch(partitions);
        this.resetOffsetMapAndRemoveWork(partitions);
        this.deregisterPartitionCounters(partitions);
        this.sm.removeStaleContainers();
    }

    /** Treats lost partitions as removed; exceptions are logged and rethrown. */
    public void onPartitionsLost(Collection<TopicPartition> partitions) {
        try {
            log.info("Lost partitions: {}", partitions);
            this.onPartitionsRemoved(partitions);
        }
        catch (Exception e) {
            log.error("Error in onPartitionsLost", e);
            throw e;
        }
    }

    /** Forwards each successfully committed offset to its partition's state. */
    public void onOffsetCommitSuccess(Map<TopicPartition, OffsetAndMetadata> committed) {
        committed.forEach((tp, meta) -> this.getPartitionState(tp).onOffsetCommitSuccess(meta));
    }

    /**
     * Replaces each partition's state with the {@link RemovedPartitionState} singleton, then tells the previous state
     * it was removed.
     * <p>
     * A partition whose state was never loaded (for example because loading failed during assignment) has nothing to
     * tear down. It is skipped instead of dereferencing {@code null}, which would abort removal of the remaining
     * partitions.
     */
    private void resetOffsetMapAndRemoveWork(Collection<TopicPartition> allRemovedPartitions) {
        for (TopicPartition removedPartition : allRemovedPartitions) {
            PartitionState<K, V> previousState = this.partitionStates.get(removedPartition);
            // install the marker first - presumably so anything still looking this partition up sees a removed state
            this.partitionStates.put(removedPartition, RemovedPartitionState.getSingleton());
            if (previousState == null) {
                log.debug("Removed partition {} had no loaded state, nothing to tear down", removedPartition);
                continue;
            }
            previousState.onPartitionsRemoved(this.sm);
        }
    }

    /** @return the current assignment epoch of {@code partition}, or {@code null} if it has never been assigned */
    public Long getEpochOfPartition(TopicPartition partition) {
        return this.partitionsAssignmentEpochs.get(partition);
    }

    /** Bumps the assignment epoch of each partition: 0 the first time it is seen, previous + 1 afterwards. */
    private void incrementPartitionAssignmentEpoch(Collection<TopicPartition> partitions) {
        for (TopicPartition partition : partitions) {
            this.partitionsAssignmentEpochs.merge(partition, 0L, (previousEpoch, ignored) -> previousEpoch + 1L);
        }
    }

    public boolean isAllowedMoreRecords(TopicPartition tp) {
        return this.getPartitionState(tp).isAllowedMoreRecords();
    }

    public boolean isAllowedMoreRecords(WorkContainer<?, ?> wc) {
        return this.isAllowedMoreRecords(wc.getTopicPartition());
    }

    /** @return whether any currently assigned partition has incomplete offsets */
    public boolean hasIncompleteOffsets() {
        return this.getAssignedPartitions().values().stream().anyMatch(PartitionState::hasIncompleteOffsets);
    }

    /** @return the total number of incomplete offsets across currently assigned partitions */
    public long getNumberOfIncompleteOffsets() {
        return this.getAssignedPartitions().values().stream()
                .mapToLong(PartitionState::getNumberOfIncompleteOffsets)
                .sum();
    }

    public long getHighestSeenOffset(TopicPartition tp) {
        return this.getPartitionState(tp).getOffsetHighestSeen();
    }

    public void onSuccess(WorkContainer<K, V> wc) {
        this.getPartitionState(wc.getTopicPartition()).onSuccess(wc.offset());
    }

    public void onFailure(WorkContainer<K, V> wc) {
        this.getPartitionState(wc.getTopicPartition()).onFailure(wc);
    }

    /** Hands each partition's polled batch to that partition's state for possible registration as work. */
    void maybeRegisterNewRecordAsWork(EpochAndRecordsMap<K, V> recordsMap) {
        log.debug("Incoming {} new records...", recordsMap.count());
        for (EpochAndRecordsMap.RecordsAndEpoch recordsAndEpoch : recordsMap.getRecordMap().values()) {
            PartitionState<K, V> partitionState = this.getPartitionState(recordsAndEpoch);
            partitionState.maybeRegisterNewPollBatchAsWork(recordsAndEpoch);
        }
    }

    /** @return the commit data of every assigned partition that reports itself dirty, keyed by partition */
    public Map<TopicPartition, OffsetAndMetadata> collectDirtyCommitData() {
        Map<TopicPartition, OffsetAndMetadata> dirties = new HashMap<>();
        for (PartitionState<K, V> state : this.getAssignedPartitions().values()) {
            state.getCommitDataIfDirty().ifPresent(meta -> dirties.put(state.getTp(), meta));
        }
        return dirties;
    }

    /** @return an unmodifiable snapshot of the partitions whose state is not the removed marker */
    private Map<TopicPartition, PartitionState<K, V>> getAssignedPartitions() {
        return Collections.unmodifiableMap(this.partitionStates.entrySet().stream()
                .filter(entry -> !entry.getValue().isRemoved())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
    }

    public boolean couldBeTakenAsWork(WorkContainer<K, V> workContainer) {
        return this.getPartitionState(workContainer).couldBeTakenAsWork(workContainer);
    }

    /** @return whether any tracked partition state (including removed markers) reports itself dirty */
    public boolean isDirty() {
        return this.partitionStates.values().stream().anyMatch(PartitionState::isDirty);
    }

    /** Registers gauges for the assigned-partition count and the total number of incomplete offsets. */
    private void initMetrics() {
        this.numberOfPartitionsGauge = this.pcMetrics.gaugeFromMetricDef(PCMetricsDef.NUMBER_OF_PARTITIONS, this, pm -> this.getAssignedPartitions().size(), new Tag[0]);
        this.totalIncompletesGauge = this.pcMetrics.gaugeFromMetricDef(PCMetricsDef.INCOMPLETE_OFFSETS_TOTAL, this, partitionStateManager -> partitionStateManager.getAssignedPartitions().values().stream().mapToInt(PartitionState::getNumberOfIncompleteOffsets).sum(), new Tag[0]);
    }

    public static double getUSED_PAYLOAD_THRESHOLD_MULTIPLIER() {
        return USED_PAYLOAD_THRESHOLD_MULTIPLIER;
    }

    public static void setUSED_PAYLOAD_THRESHOLD_MULTIPLIER(double USED_PAYLOAD_THRESHOLD_MULTIPLIER) {
        PartitionStateManager.USED_PAYLOAD_THRESHOLD_MULTIPLIER = USED_PAYLOAD_THRESHOLD_MULTIPLIER;
    }
}

