/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Arranges a batch of flattened images into one tiled raster, e.g. for visualising filters.
 *
 * <p>For a rank-2 {@code input}, row {@code k} is one image. It is reshaped to
 * {@code imageShape} and placed at grid cell ({@code k / tileShape[1]}, {@code k % tileShape[1]})
 * of a {@code tileShape[0] x tileShape[1]} grid. Neighbouring tiles are separated by
 * {@code tileSpacing} zero-valued pixels. Grid cells without a corresponding image stay zero.
 *
 * <p>For any other rank, the first four slices of {@code input} are each tiled in this way and
 * written into the four channels of a {@code [height, width, 4]} raster.
 *
 * <p>Instances hold mutable state (the last plot) and are not synchronized.
 */
public class PlotFilters {
    /** Number of output channels produced for non-rank-2 input. */
    private static final int CHANNELS = 4;

    /** Result of the most recent {@link #plot()} call; {@code null} until then. */
    private INDArray plot;
    private INDArray input;
    /** Grid size as {rows, columns} of tiles. */
    private int[] tileShape;
    /** Gap between adjacent tiles as {vertical, horizontal} pixel counts. */
    private int[] tileSpacing = new int[] {0, 0};
    /** Shape each image is reshaped to before tiling, as {height, width}. */
    private int[] imageShape;
    private boolean scaleRowsToInterval = true;
    private boolean outputPixels = true;

    /**
     * Creates a plotter.
     *
     * @param input images to tile (see the class documentation for the expected layout)
     * @param tileShape grid size as {rows, columns}
     * @param tileSpacing gap between tiles as {vertical, horizontal}; {@code null} means no gap
     * @param imageShape shape of a single image as {height, width}
     */
    public PlotFilters(INDArray input, int[] tileShape, int[] tileSpacing, int[] imageShape) {
        this.input = input;
        this.tileShape = tileShape;
        this.tileSpacing = tileSpacing == null ? new int[] {0, 0} : tileSpacing;
        this.imageShape = imageShape;
    }

    public INDArray getInput() {
        return this.input;
    }

    public void setInput(INDArray input) {
        this.input = input;
    }

    /** Returns whether each image is rescaled by {@link #scale(INDArray)} before tiling. */
    public boolean isScaleRowsToInterval() {
        return this.scaleRowsToInterval;
    }

    /** Sets whether each image is rescaled by {@link #scale(INDArray)} before tiling. */
    public void setScaleRowsToInterval(boolean scaleRowsToInterval) {
        this.scaleRowsToInterval = scaleRowsToInterval;
    }

    /** Returns whether image values are multiplied by 255 before tiling. */
    public boolean isOutputPixels() {
        return this.outputPixels;
    }

    /** Sets whether image values are multiplied by 255 before tiling. */
    public void setOutputPixels(boolean outputPixels) {
        this.outputPixels = outputPixels;
    }

    /**
     * Linearly rescales {@code toScale} so its values span approximately [0, 1].
     *
     * <p>The minimum is subtracted, and the result is divided by the value range
     * {@code max - min} plus {@link Nd4j#EPS_THRESHOLD}. The epsilon avoids division by zero for
     * a constant array. {@code toScale} itself is not modified.
     *
     * @param toScale array to rescale
     * @return a new, rescaled array
     */
    public INDArray scale(INDArray toScale) {
        // Integer.MAX_VALUE as the dimension reduces over all elements (ND4J convention).
        double min = toScale.min(new int[] {Integer.MAX_VALUE}).getDouble(0);
        double max = toScale.max(new int[] {Integer.MAX_VALUE}).getDouble(0);
        // Divide by the range, not by the raw max: after subtracting min, the largest value is
        // (max - min), so only this divisor maps the data onto the unit interval.
        return toScale.sub(min).muli(1.0 / (Nd4j.EPS_THRESHOLD + (max - min)));
    }

    /**
     * Builds the tiled raster from the current input and stores it for {@link #getPlot()}.
     *
     * @throws IllegalArgumentException if the input is not rank 2 and has fewer than four
     *     slices along dimension 0
     */
    public void plot() {
        int[] retShape = new int[] {
            (this.imageShape[0] + this.tileSpacing[0]) * this.tileShape[0] - this.tileSpacing[0],
            (this.imageShape[1] + this.tileSpacing[1]) * this.tileShape[1] - this.tileSpacing[1]
        };
        if (this.input.rank() == 2) {
            this.plot = this.plotSection(this.input, retShape);
            return;
        }
        if (this.input.size(0) < CHANNELS) {
            throw new IllegalArgumentException("Expected at least " + CHANNELS
                    + " slices along dimension 0 for non-rank-2 input, got " + this.input.size(0));
        }
        INDArray raster = Nd4j.zeros(new int[] {retShape[0], retShape[1], CHANNELS});
        for (int channel = 0; channel < CHANNELS; ++channel) {
            INDArray section = this.plotSection(this.input.slice(channel), retShape);
            // Write into channel `channel` of the [height, width, 4] raster. putSlice() would
            // target dimension 0, whose slices are [width, 4] and do not match `section`.
            raster.put(new INDArrayIndex[] {
                NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(channel)
            }, section);
        }
        // Publish only a fully built raster.
        this.plot = raster;
    }

    /**
     * Returns the raster produced by the last {@link #plot()} call.
     *
     * @throws IllegalStateException if {@link #plot()} has not been called yet
     */
    public INDArray getPlot() {
        if (this.plot == null) {
            throw new IllegalStateException("Please call plot() first.");
        }
        return this.plot;
    }

    /**
     * Tiles the rows of {@code input} into a zero-initialised array of shape {@code retShape}.
     *
     * @param input images, one per index along dimension 0
     * @param retShape output shape as {height, width}
     * @return the tiled section
     */
    private INDArray plotSection(INDArray input, int[] retShape) {
        INDArray ret = Nd4j.zeros(retShape);
        if (input.getLeadingOnes() == 2) {
            // Keep only the last two dimensions, dropping the two leading singleton ones.
            input = input.reshape(input.size(-2), input.size(-1));
        }
        int h = this.imageShape[0];
        int w = this.imageShape[1];
        int hs = this.tileSpacing[0];
        int ws = this.tileSpacing[1];
        for (int tileRow = 0; tileRow < this.tileShape[0]; ++tileRow) {
            for (int tileCol = 0; tileCol < this.tileShape[1]; ++tileCol) {
                int imageIndex = tileRow * this.tileShape[1] + tileCol;
                if (imageIndex >= input.size(0)) {
                    // Fewer images than grid cells: leave this tile zero.
                    continue;
                }
                INDArray image = input.get(new INDArrayIndex[] {NDArrayIndex.point(imageIndex)});
                image = image.reshape(this.imageShape);
                if (this.scaleRowsToInterval) {
                    image = this.scale(image);
                }
                if (this.outputPixels) {
                    // Out-of-place: without scaling, `image` may be a view of the caller's input.
                    image = image.mul(255);
                }
                int rowBegin = tileRow * (h + hs);
                int colBegin = tileCol * (w + ws);
                // interval(begin, end) excludes `end`, so each tile spans exactly h x w pixels.
                INDArrayIndex rowIndex = NDArrayIndex.interval(rowBegin, rowBegin + h);
                INDArrayIndex colIndex = NDArrayIndex.interval(colBegin, colBegin + w);
                ret.put(new INDArrayIndex[] {rowIndex, colIndex}, image);
            }
        }
        return ret;
    }
}

