package org.nd4j.linalg.dataset.api.preprocessor.stats;

import java.util.Arrays;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/stats/MinMaxStats.class */
/**
 * Per-column lower/upper bounds (minimum and maximum) used for min-max normalization.
 *
 * <p>Instances are normally produced by {@link Builder}. The constructor pads the upper bound
 * element-wise wherever {@code upper - lower} is below {@link Nd4j#EPS_THRESHOLD}, so that the
 * range returned by {@link #getRange()} does not contain zero-width entries.
 */
public class MinMaxStats implements NormalizerStats {
    private static final Logger log = LoggerFactory.getLogger(MinMaxStats.class);

    private final INDArray lower;
    private final INDArray upper;
    /** {@code upper - lower}; computed lazily by {@link #getRange()} and then cached. */
    private INDArray range;

    /**
     * Accumulates running per-column minima and maxima over any number of batches and turns
     * them into a {@link MinMaxStats} via {@link #build()}.
     */
    public static class Builder implements NormalizerStats.Builder<MinMaxStats> {
        private INDArray runningLower;
        private INDArray runningUpper;

        /**
         * Adds the features of a data set (honouring its features mask) to the running bounds.
         *
         * @param dataSet data set whose features are accumulated
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        @Override
        public Builder addFeatures(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return add(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        /**
         * Adds the labels of a data set (honouring its labels mask) to the running bounds.
         *
         * @param dataSet data set whose labels are accumulated
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        @Override
        public Builder addLabels(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            return add(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        /**
         * Folds the min/max of one batch into the running bounds.
         *
         * <p>The batch is first passed through {@link DataSetUtil#tailor2d(INDArray, INDArray)};
         * a {@code null} result is treated as "nothing to add" and the builder is left unchanged.
         * Min and max are then reduced along dimension 0 (presumably the example axis of the 2d
         * view produced by {@code tailor2d} — confirm against DataSetUtil).
         *
         * @param data batch values
         * @param mask optional mask for {@code data}; may be null
         * @return this builder
         * @throws NullPointerException if {@code data} is null
         * @throws IllegalStateException if the computed min and max have different shapes
         */
        @Override
        public Builder add(@NonNull INDArray data, INDArray mask) {
            if (data == null) {
                throw new NullPointerException("data");
            }
            INDArray flat = DataSetUtil.tailor2d(data, mask);
            if (flat == null) {
                return this;
            }
            INDArray batchMin = flat.min(0);
            INDArray batchMax = flat.max(0);
            if (!Arrays.equals(batchMin.shape(), batchMax.shape())) {
                throw new IllegalStateException("Data min and max must be same shape. Likely a bug in the operation changing the input?");
            }
            if (runningLower == null) {
                // First batch: take private copies, because the updates below write in place
                // (dup = false) and must not write into the batch reduction results.
                runningLower = batchMin.dup();
                runningUpper = batchMax.dup();
            } else {
                Transforms.min(runningLower, batchMin, false);
                Transforms.max(runningUpper, batchMax, false);
            }
            return this;
        }

        /**
         * Alias of {@link #addFeatures(DataSet)}, kept for callers written against the
         * decompiler-renamed method name.
         *
         * @deprecated use {@link #addFeatures(DataSet)}
         */
        @Deprecated
        public NormalizerStats.Builder<MinMaxStats> addFeatures2(DataSet dataSet) {
            return addFeatures(dataSet);
        }

        /**
         * Alias of {@link #addLabels(DataSet)}, kept for callers written against the
         * decompiler-renamed method name.
         *
         * @deprecated use {@link #addLabels(DataSet)}
         */
        @Deprecated
        public NormalizerStats.Builder<MinMaxStats> addLabels2(DataSet dataSet) {
            return addLabels(dataSet);
        }

        /**
         * Alias of {@link #add(INDArray, INDArray)}, kept for callers written against the
         * decompiler-renamed method name.
         *
         * @deprecated use {@link #add(INDArray, INDArray)}
         */
        @Deprecated
        public NormalizerStats.Builder<MinMaxStats> add2(INDArray data, INDArray mask) {
            return add(data, mask);
        }

        /**
         * Creates a {@link MinMaxStats} from the bounds accumulated so far.
         *
         * @return new statistics built from copies of the running bounds
         * @throws IllegalStateException if no data has been added
         */
        @Override
        public MinMaxStats build() {
            if (runningLower == null) {
                throw new IllegalStateException("No data was added, statistics cannot be determined");
            }
            // Allocate outside any active workspace (presumably so the result is not tied to a
            // workspace's lifetime). The copies also keep the constructor's in-place padding of
            // `upper` from touching this builder's running state.
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                return new MinMaxStats(runningLower.dup(), runningUpper.dup());
            }
        }
    }

    /**
     * Creates statistics from explicit bounds.
     *
     * <p>Wherever {@code upper - lower} is below {@link Nd4j#EPS_THRESHOLD}, the difference is
     * added to {@code upper} so that the range becomes (approximately) the threshold, avoiding
     * division by zero during normalization.
     *
     * <p>NOTE(review): the padding is applied to the caller's {@code upper} array in place
     * ({@code addi}), and both arrays are stored without copying.
     *
     * @param lower per-column lower bounds
     * @param upper per-column upper bounds; may be modified in place
     * @throws NullPointerException if either argument is null
     */
    public MinMaxStats(@NonNull INDArray lower, @NonNull INDArray upper) {
        if (lower == null) {
            throw new NullPointerException("lower");
        }
        if (upper == null) {
            throw new NullPointerException("upper");
        }
        INDArray diff = upper.sub(lower);
        // Element-wise max(diff, EPS) - diff: positive exactly where diff < EPS, zero elsewhere.
        INDArray padding = Transforms.max(diff, Nd4j.EPS_THRESHOLD).subi(diff);
        if (padding.sumNumber().doubleValue() > 0.0d) {
            log.info("API_INFO: max val minus min val found to be zero. Transform will round up to epsilon to avoid nans.");
            upper.addi(padding);
        }
        this.lower = lower;
        this.upper = upper;
    }

    /**
     * Returns {@code upper - lower}, computing it on first use (outside any workspace) and
     * caching it afterwards.
     *
     * @return the per-column range
     */
    public INDArray getRange() {
        if (range == null) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                range = upper.sub(lower);
            }
        }
        return range;
    }

    /**
     * Compares lower bounds, upper bounds and ranges. Note that this may trigger the lazy
     * computation of {@link #getRange()} on either instance.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MinMaxStats)) {
            return false;
        }
        MinMaxStats other = (MinMaxStats) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getLower(), other.getLower())
                && Objects.equals(getUpper(), other.getUpper())
                && Objects.equals(getRange(), other.getRange());
    }

    /** Symmetry hook for {@link #equals(Object)} in subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof MinMaxStats;
    }

    /** Hash over lower, upper and range; may trigger the lazy computation of the range. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(getLower());
        result = result * prime + hashOrDefault(getUpper());
        result = result * prime + hashOrDefault(getRange());
        return result;
    }

    /** Hash of {@code value}, or 43 for null (the constant used by the original formula). */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** @return the per-column lower bounds */
    public INDArray getLower() {
        return this.lower;
    }

    /** @return the per-column upper bounds (possibly padded by the constructor) */
    public INDArray getUpper() {
        return this.upper;
    }
}
