/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation;

import com.google.common.util.concurrent.AtomicDouble;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.MessageHandler;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.NDArrayCompressor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MessageHandler} that compresses gradient updates before handing them to a
 * {@link GradientsAccumulator} (via {@link GradientsAccumulator#receiveUpdate(INDArray)}).
 *
 * <p>Every calling thread keeps its own encoding state in {@link ThreadLocal}s: a clone of the
 * initial {@link ThresholdAlgorithm}, an optional clone of the initial
 * {@link ResidualPostProcessor}, the current/last threshold, the current encoding mode and the
 * last sparsity ratio. Per thread, updates are encoded either with ND4J bitmap encoding
 * ("dense" mode, the mode a thread starts in) or with threshold encoding ("sparse" mode):
 * <ul>
 *   <li>in bitmap mode, the handler switches to threshold encoding for the next iteration when
 *       {@code bitmapEncode} reports fewer than half of {@code length / 16 + 5} values;</li>
 *   <li>in threshold mode, the handler falls back to bitmap encoding (for the current iteration
 *       and onwards) when the threshold-encoded length is at least {@code length / 16}.</li>
 * </ul>
 */
public class EncodingHandler implements MessageHandler {
    private static final Logger log = LoggerFactory.getLogger(EncodingHandler.class);

    /** Minimum interval, in milliseconds, between two threshold log messages (shared by all threads). */
    public static final long THRESHOLD_LOG_FREQ_MS = 10000L;

    protected transient GradientsAccumulator accumulator;
    /** Prototype cloned once per thread to obtain that thread's threshold algorithm. */
    protected ThresholdAlgorithm initialThresholdAlgorithm;
    /** Optional prototype cloned once per thread; {@code null} disables residual post-processing. */
    protected ResidualPostProcessor initialResidualPostProcessor;
    /**
     * Optional fraction of the update length; {@code length * boundary} is passed to
     * {@code thresholdEncode} as its boundary argument ({@code null} = no boundary).
     */
    protected Double boundary;
    /** When true, per-iteration residual statistics are logged at INFO level. */
    protected boolean encodingDebugMode;
    protected NDArrayCompressor compressor;
    /** Lazily computed integer boundary; -1 until first computed. */
    protected AtomicInteger atomicBoundary = new AtomicInteger(-1);
    protected ThreadLocal<ThresholdAlgorithm> thresholdAlgorithm = new ThreadLocal<>();
    /** Every per-thread threshold algorithm, keyed by thread id, for {@link #getAverageThresholdAlgorithm()}. */
    protected Map<Long, ThresholdAlgorithm> allThreadThresholdAlgorithms = new ConcurrentHashMap<Long, ThresholdAlgorithm>();
    protected ThreadLocal<ResidualPostProcessor> residualPostProcessor = new ThreadLocal<>();
    protected ThreadLocal<AtomicLong> iterations = new ThreadLocal<>();
    protected ThreadLocal<AtomicLong> lastStep = new ThreadLocal<>();
    protected ThreadLocal<AtomicDouble> lastThreshold = new ThreadLocal<>();
    protected ThreadLocal<AtomicDouble> lastSparsityRatio = new ThreadLocal<>();
    protected ThreadLocal<AtomicDouble> currentThreshold = new ThreadLocal<>();
    /** Per-thread encoding mode: true = bitmap (dense) encoding, false = threshold (sparse) encoding. */
    protected ThreadLocal<AtomicBoolean> bitmapMode = new ThreadLocal<>();
    protected ThreadLocal<AtomicBoolean> lastIterWasDense = new ThreadLocal<>();
    protected AtomicLong lastThresholdLogTime = new AtomicLong();
    protected static ThreadLocal<DecimalFormat> formatter = new ThreadLocal<>();
    protected static ThreadLocal<DecimalFormat> formatter2 = new ThreadLocal<>();

    /**
     * @param thresholdAlgorithm    prototype threshold algorithm, cloned per thread
     * @param residualPostProcessor optional prototype residual post-processor (may be {@code null})
     * @param boundary              optional boundary fraction for threshold encoding (may be {@code null})
     * @param encodingDebugMode     whether to log residual statistics every iteration
     */
    public EncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Double boundary, boolean encodingDebugMode) {
        this.initialThresholdAlgorithm = thresholdAlgorithm;
        this.initialResidualPostProcessor = residualPostProcessor;
        this.boundary = boundary;
        this.encodingDebugMode = encodingDebugMode;
    }

    /**
     * Binds this handler to its accumulator and looks up the ND4J "THRESHOLD" compressor.
     *
     * @throws NullPointerException       if {@code accumulator} is null
     * @throws ND4JIllegalStateException  if no THRESHOLD compressor is registered
     */
    @Override
    public void initialize(@NonNull GradientsAccumulator accumulator) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator is marked @NonNull but is null");
        }
        this.accumulator = accumulator;
        this.compressor = Nd4j.getCompressor().getCompressor("THRESHOLD");
        if (this.compressor == null) {
            throw new ND4JIllegalStateException("Can't find Threshold compressor implementation!");
        }
    }

    /**
     * Encodes {@code updates} with this thread's current encoding mode.
     *
     * <p>NOTE(review): the ND4J encode calls receive {@code updates} and the same array is then
     * passed to the residual post-processor as "residuals", which suggests the encoders modify
     * it in place — confirm against the ND4J executioner documentation.
     *
     * @param iteration current iteration number
     * @param epoch     current epoch number
     * @param updates   update array to encode
     * @return the encoded array, or {@code null} when threshold encoding produced nothing
     */
    public INDArray encodeUpdates(int iteration, int epoch, INDArray updates) {
        ThresholdAlgorithm algorithm = getOrCreateThreadThresholdAlgorithm();

        // Feed the outcome of this thread's previous iteration (if any) into the threshold algorithm.
        Double lastThr = null;
        Boolean lastWasDense = null;
        Double lastSparsity = null;
        if (lastThreshold.get() != null) {
            lastThr = lastThreshold.get().get();
            lastWasDense = lastIterWasDense.get().get();
            AtomicDouble sparsityHolder = lastSparsityRatio.get();
            if (!lastWasDense && sparsityHolder != null) {
                lastSparsity = sparsityHolder.get();
            }
        }
        double currThreshold = algorithm.calculateThreshold(iteration, epoch, lastThr, lastWasDense, lastSparsity, updates);

        if (bitmapMode.get() == null) {
            // First call on this thread: start in bitmap (dense) mode.
            bitmapMode.set(new AtomicBoolean(true));
            currentThreshold.set(new AtomicDouble(currThreshold));
            iterations.set(new AtomicLong(0L));
            lastStep.set(new AtomicLong(0L));
            lastThreshold.set(new AtomicDouble(currThreshold));
            lastIterWasDense.set(new AtomicBoolean());
        }
        currentThreshold.get().set(currThreshold);
        lastThreshold.get().set(currThreshold);

        residualDebugOutputIfRequired(updates);
        iterations.get().incrementAndGet();

        if (boundary != null && atomicBoundary.get() < 0) {
            atomicBoundary.compareAndSet(-1, (int) ((double) updates.lengthLong() * boundary));
        }

        INDArray encoded;
        if (!bitmapMode.get().get()) {
            // Sparse mode: threshold encoding.
            encoded = Nd4j.getExecutioner().thresholdEncode(updates, currentThreshold.get().get(),
                    boundary == null ? null : Integer.valueOf(atomicBoundary.get()));
            if (encoded == null) {
                // Nothing to send this iteration; stay in threshold mode.
                setLastSparsityRatio(0.0);
                lastIterWasDense.get().set(false);
                logThresholdIfReq(false, iteration, epoch);
                return null;
            }

            // The first int of the encoded buffer is treated as the encoded length.
            double encLen = encoded.data().getInt(0L);
            if (encLen >= (double) (updates.lengthLong() / 16L)) {
                // Too dense for threshold encoding: re-encode as a bitmap and stay in bitmap mode.
                log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", new Object[]{iteration, epoch, currThreshold, encLen});
                bitmapMode.get().set(true);
                encoded = allocateBitmapTarget(updates);
                Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
                applyPostProcessor(iteration, epoch, currThreshold, updates);
                lastSparsityRatio.set(null);
                lastIterWasDense.get().set(true);
                logThresholdIfReq(true, iteration, epoch);
                return encoded;
            }

            setLastSparsityRatio(encLen / (double) updates.length());
            lastIterWasDense.get().set(false);
        } else {
            // Dense mode: bitmap encoding.
            encoded = allocateBitmapTarget(updates);
            long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
            if (values < bitmapBufferLength(updates) / 2L) {
                // Few enough values: use threshold encoding from the next iteration on.
                bitmapMode.get().set(false);
                log.debug("Switched to threshold encoding: iteration {}, epoch {}, threshold {}, number of values {}", new Object[]{iteration, epoch, currThreshold, values});
            }
            lastSparsityRatio.set(null);
            lastIterWasDense.get().set(true);
        }

        applyPostProcessor(iteration, epoch, currThreshold, updates);
        logThresholdIfReq(lastIterWasDense.get().get(), iteration, epoch);
        return encoded;
    }

    /**
     * Returns this thread's threshold algorithm. On the first call (or after
     * {@link #getAverageThresholdAlgorithm()} reset the thread-locals), it clones the initial
     * algorithm and post-processor while holding this handler's monitor.
     */
    private ThresholdAlgorithm getOrCreateThreadThresholdAlgorithm() {
        ThresholdAlgorithm ta = thresholdAlgorithm.get();
        if (ta != null) {
            return ta;
        }
        // NOTE(review): presumably synchronized because clone() of the shared prototypes is not
        // known to be thread-safe.
        synchronized (this) {
            ta = initialThresholdAlgorithm.clone();
            // Use the local reference rather than re-reading the field, which
            // getAverageThresholdAlgorithm() may replace concurrently.
            thresholdAlgorithm.set(ta);
            allThreadThresholdAlgorithms.put(Thread.currentThread().getId(), ta);
            if (initialResidualPostProcessor != null) {
                residualPostProcessor.set(initialResidualPostProcessor.clone());
            }
        }
        return ta;
    }

    /** Number of ints allocated for a bitmap-encoded copy of {@code updates}. */
    private static long bitmapBufferLength(INDArray updates) {
        return updates.lengthLong() / 16L + 5L;
    }

    /** Allocates an int buffer of {@link #bitmapBufferLength} and wraps it with the shape info of {@code updates}. */
    private static INDArray allocateBitmapTarget(INDArray updates) {
        DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(bitmapBufferLength(updates));
        return Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
    }

    /** Stores {@code ratio} as this thread's last sparsity ratio, creating the holder if needed. */
    private void setLastSparsityRatio(double ratio) {
        AtomicDouble holder = lastSparsityRatio.get();
        if (holder == null) {
            lastSparsityRatio.set(new AtomicDouble(ratio));
        } else {
            holder.set(ratio);
        }
    }

    /**
     * Passes {@code residuals} to this thread's residual post-processor. This is a no-op when no
     * initial post-processor was configured.
     */
    public void applyPostProcessor(int iteration, int epoch, Double lastThreshold, INDArray residuals) {
        if (this.initialResidualPostProcessor == null) {
            return;
        }
        this.residualPostProcessor.get().processResidual(iteration, epoch, lastThreshold, residuals);
    }

    /**
     * Not supported by this handler.
     *
     * @deprecated always throws; decoding is not performed by this handler.
     * @throws UnsupportedOperationException always
     */
    @Deprecated
    public INDArray decodeUpdates(INDArray message) {
        throw new UnsupportedOperationException();
    }

    /** Hands an encoded message to the accumulator; iteration/epoch are unused here. */
    protected void sendMessage(INDArray message, int iterationNumber, int epochNumber) {
        this.accumulator.receiveUpdate(message);
    }

    /**
     * Encodes {@code updates} and, if encoding produced a message, sends it to the accumulator.
     *
     * @return true if a message was sent, false if encoding produced nothing
     */
    @Override
    public boolean broadcastUpdates(INDArray updates, int iterationNumber, int epochNumber) {
        INDArray message = this.encodeUpdates(iterationNumber, epochNumber, updates);
        if (message != null) {
            this.sendMessage(message, iterationNumber, epochNumber);
            return true;
        }
        return false;
    }

    /**
     * Logs this thread's last threshold at INFO level.
     *
     * <p>Logging happens at most once per {@link #THRESHOLD_LOG_FREQ_MS} across all threads; the
     * CAS on {@code lastThresholdLogTime} lets only one thread win each window.
     */
    protected void logThresholdIfReq(boolean denseUpdates, int iter, int epoch) {
        long now = System.currentTimeMillis();
        long lastLog = this.lastThresholdLogTime.get();
        if (lastLog + THRESHOLD_LOG_FREQ_MS <= now && this.lastThresholdLogTime.compareAndSet(lastLog, now)) {
            String lastThresholdStr = format(this.lastThreshold.get().get());
            long threadId = Thread.currentThread().getId();
            if (denseUpdates) {
                log.info("Threshold at iter {}, epoch {} [thread {}]: {}, DENSE updates", new Object[]{iter, epoch, threadId, lastThresholdStr});
            } else {
                AtomicDouble d = this.lastSparsityRatio.get();
                String lastSparsityStr = d == null ? "-" : format(d.get());
                // One placeholder per argument (the thread id was previously missing its placeholder).
                log.info("Threshold at iter {}, epoch {} [thread {}]: {}, SPARSE updates, last sparsity ratio: {}", new Object[]{iter, epoch, threadId, lastThresholdStr, lastSparsityStr});
            }
        }
    }

    /**
     * When {@code encodingDebugMode} is enabled, logs statistics of |residual| relative to this
     * thread's current threshold: mean, max, 50/95/99/99.9/99.99th percentiles, and the count and
     * fraction of entries >= threshold. This method does nothing otherwise.
     */
    protected void residualDebugOutputIfRequired(INDArray residual) {
        if (!this.encodingDebugMode) {
            return;
        }
        double currThreshold = this.currentThreshold.get().get();
        INDArray absResidual = Transforms.abs(residual, true);

        double dAmean = absResidual.meanNumber().doubleValue();
        double dAMax = absResidual.maxNumber().doubleValue();
        double dPc50 = absResidual.percentileNumber(50).doubleValue();
        double dPc95 = absResidual.percentileNumber(95).doubleValue();
        double dPc99 = absResidual.percentileNumber(99).doubleValue();
        double dPc999 = absResidual.percentileNumber(99.9).doubleValue();
        double dPc9999 = absResidual.percentileNumber(99.99).doubleValue();

        long length = absResidual.length();
        long countAbsGTEThreshold = absResidual.gte(currThreshold).sumNumber().longValue();
        double sparsity = (double) countAbsGTEThreshold / (double) length;

        // format() already emits a lowercase exponent marker, so no further replacement is needed.
        log.info("Encoding debug info, residual vector: length: {}, threshold: {}, count > thr: {}, sparsity: {}, amean: {} ({}x); amax: {} ({}x); 50%: {} ({}x); 95%: {} ({}x); 99%: {} ({}x);  99.9%: {} ({}x); 99.99%: {} ({}x)",
                new Object[]{length, format(currThreshold), countAbsGTEThreshold, format(sparsity),
                        format(dAmean), format(dAmean / currThreshold),
                        format(dAMax), format(dAMax / currThreshold),
                        format(dPc50), format(dPc50 / currThreshold),
                        format(dPc95), format(dPc95 / currThreshold),
                        format(dPc99), format(dPc99 / currThreshold),
                        format(dPc999), format(dPc999 / currThreshold),
                        format(dPc9999), format(dPc9999 / currThreshold)});
    }

    /**
     * Formats {@code d} for logging.
     *
     * <p>The output depends on the value:
     * <ul>
     *   <li>{@code "0.0"} for zero;</li>
     *   <li>pattern {@code 0.###} when 0.1 <= |d| < 100;</li>
     *   <li>otherwise scientific pattern {@code 0.###E0} with a lowercase {@code e}.</li>
     * </ul>
     * Formatters are cached per thread because {@link DecimalFormat} is not thread-safe.
     */
    protected static String format(double d) {
        if (d == 0.0) {
            return "0.0";
        }
        if (d <= -0.1 && d > -100.0 || d >= 0.1 && d < 100.0) {
            if (formatter2.get() == null) {
                formatter2.set(new DecimalFormat("0.###"));
            }
            return formatter2.get().format(d);
        }
        if (formatter.get() == null) {
            formatter.set(new DecimalFormat("0.###E0"));
        }
        DecimalFormat df = formatter.get();
        return df.format(d).replace('E', 'e');
    }

    /**
     * Combines all per-thread threshold algorithms into one result.
     *
     * <p>Returns {@code null} if no thread has encoded yet. With a single thread, that thread's
     * live instance is returned and no state is reset. Otherwise, the algorithms are merged with
     * the {@link ThresholdAlgorithmReducer} from {@code newReducer()}, and then the per-thread
     * algorithms are discarded so each thread re-clones the initial algorithm on its next encode.
     */
    public ThresholdAlgorithm getAverageThresholdAlgorithm() {
        Collection<ThresholdAlgorithm> all = this.allThreadThresholdAlgorithms.values();
        if (all.isEmpty()) {
            return null;
        }
        if (all.size() == 1) {
            return all.iterator().next();
        }
        ThresholdAlgorithmReducer reducer = null;
        for (ThresholdAlgorithm ta : all) {
            if (reducer == null) {
                reducer = ta.newReducer();
            }
            reducer.add(ta);
        }
        ThresholdAlgorithm result = reducer.getFinalResult();
        // Drop the old per-thread instances so the next encode on each thread starts from a fresh clone.
        this.thresholdAlgorithm = new ThreadLocal<>();
        this.allThreadThresholdAlgorithms.clear();
        return result;
    }
}

