/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.DFunction;
import deepboof.Tensor;
import deepboof.backward.DSpatialPadding2D;
import deepboof.forward.ConfigSpatial;
import deepboof.impl.forward.standard.SpatialWindowChannel;
import deepboof.misc.TensorFactory;
import deepboof.misc.TensorOps;
import java.util.List;

/**
 * Adds a backwards pass to {@link SpatialWindowChannel}, processing one
 * (batch, channel) pair at a time.
 *
 * <p>For every batch element and channel, {@link #backwardsChannel} clears the scratch tensor
 * {@link #dpadding}, asks the subclass to process every output location (through
 * {@link #backwardsAt_inner} or {@link #backwardsAt_border}), and then hands {@code dpadding}
 * to the padding object's {@code backwardsChannel} together with the caller's
 * {@code gradientInput}.
 *
 * @param <T> tensor type
 * @param <P> padding type; must support the backwards pass
 */
public abstract class DSpatialWindowChannel<T extends Tensor<T>, P extends DSpatialPadding2D<T>>
        extends SpatialWindowChannel<T, P>
        implements DFunction<T> {

    /** Flag toggled by {@link #learning()} / {@link #evaluating()}; starts out {@code false}. */
    protected boolean learningMode = false;

    /**
     * Scratch tensor, reshaped in {@link #backwardsChannel} to the last two axes of the padding's
     * shape. NOTE(review): presumably the gradient with respect to one padded channel - confirm
     * against the concrete subclasses, which are the ones writing into it.
     */
    protected T dpadding;

    /**
     * Creates the function and allocates {@link #dpadding} with an empty shape, using the tensor
     * type reported by {@code padding}.
     *
     * @param config spatial configuration forwarded to the super class
     * @param padding padding forwarded to the super class; also supplies the tensor type
     */
    public DSpatialWindowChannel(ConfigSpatial config, P padding) {
        super(config, padding);
        this.dpadding = new TensorFactory<T>(padding.getTensorType()).create(new int[0]);
    }

    /**
     * Validates the shapes of all arguments and then delegates to {@link #_backwards}.
     *
     * <p>{@code input} and {@code gradientInput} are checked against {@code shapeInput},
     * {@code dout} against {@code shapeOutput} and {@code gradientParameters} against
     * {@code shapeParameters}.
     *
     * @param input tensor checked against the input shape
     * @param dout tensor checked against the output shape
     * @param gradientInput tensor checked against the input shape
     * @param gradientParameters tensors checked against the parameter shapes
     * @throws IllegalArgumentException if {@code shapeInput} is still {@code null}, i.e. the
     *         function has not been initialized
     */
    public void backwards(T input, T dout, T gradientInput, List<T> gradientParameters) {
        if (this.shapeInput == null) {
            throw new IllegalArgumentException("Must initialize first!");
        }

        TensorOps.checkShape("input", -1, this.shapeInput, input.getShape(), true);
        TensorOps.checkShape("dout", -1, this.shapeOutput, dout.getShape(), true);
        TensorOps.checkShape("gradientInput", -1, this.shapeInput, gradientInput.getShape(), true);
        TensorOps.checkShape("gradientParameters", this.shapeParameters, gradientParameters, false);

        this._backwards(input, dout, gradientInput, gradientParameters);
    }

    /**
     * Implementation of the backwards pass, invoked by {@link #backwards} after the shape checks
     * have passed. Arguments are the ones given to {@link #backwards}, unmodified.
     */
    protected abstract void _backwards(T input, T dout, T gradientInput, List<T> gradientParameters);

    /**
     * Runs the backwards pass over every batch element and channel of {@code input}.
     *
     * <p>Output locations inside {@code [outR0,outR1) x [outC0,outC1)} are dispatched to
     * {@link #backwardsAt_inner}; every other output location is dispatched to
     * {@link #backwardsAt_border}. When {@code isEntirelyBorder} reports that there is no inner
     * region, the whole {@code Ho x Wo} output is treated as border.
     *
     * <p>NOTE(review): presumably called by subclasses from {@link #_backwards}; no shape checks
     * are performed here.
     *
     * @param input input tensor; installed as the padding's input and used to set {@code N}
     * @param gradientInput tensor handed to the padding's {@code backwardsChannel} for each
     *        (batch, channel) pair
     */
    public void backwardsChannel(T input, T gradientInput) {
        this.padding.setInput(input);

        // Only the last two axes of the padding shape are used to size the scratch tensor.
        int[] paddingShape = this.padding.getShape();
        this.dpadding.reshape(paddingShape[2], paddingShape[3]);

        this.N = input.length(0);

        int paddingX0 = this.padding.getPaddingCol0();
        int paddingY0 = this.padding.getPaddingRow0();

        // Output-coordinate bounds of the inner region: lower inclusive, upper exclusive.
        int outC0 = innerLowerExtent(this.config.periodX, paddingX0);
        int outC1 = innerUpperExtent(this.config.WW, this.config.periodX, paddingX0, this.W);
        int outR0 = innerLowerExtent(this.config.periodY, paddingY0);
        int outR1 = innerUpperExtent(this.config.HH, this.config.periodY, paddingY0, this.H);

        if (this.isEntirelyBorder(outR0, outC0)) {
            for (int batchIndex = 0; batchIndex < this.N; ++batchIndex) {
                for (int channel = 0; channel < this.C; ++channel) {
                    this.dpadding.zero();
                    this.backwardsBorder(batchIndex, channel, 0, 0, this.Ho, this.Wo);
                    this.padding.backwardsChannel(this.dpadding, batchIndex, channel, gradientInput);
                }
            }
        } else {
            for (int batchIndex = 0; batchIndex < this.N; ++batchIndex) {
                for (int channel = 0; channel < this.C; ++channel) {
                    this.dpadding.zero();

                    // Inner region: output coordinates are mapped to input coordinates by
                    // scaling with the period and subtracting the leading padding.
                    for (int outRow = outR0; outRow < outR1; ++outRow) {
                        int inputRow = outRow * this.config.periodY - paddingY0;
                        for (int outCol = outC0; outCol < outC1; ++outCol) {
                            int inputCol = outCol * this.config.periodX - paddingX0;
                            this.backwardsAt_inner(input, batchIndex, channel, inputRow, inputCol, outRow, outCol);
                        }
                    }

                    // Border: rows above and below the inner region (full width), then the
                    // column strips to its left and right, so each location is visited once.
                    this.backwardsBorder(batchIndex, channel, 0, 0, outR0, this.Wo);
                    this.backwardsBorder(batchIndex, channel, outR1, 0, this.Ho, this.Wo);
                    this.backwardsBorder(batchIndex, channel, outR0, 0, outR1, outC0);
                    this.backwardsBorder(batchIndex, channel, outR0, outC1, outR1, this.Wo);

                    this.padding.backwardsChannel(this.dpadding, batchIndex, channel, gradientInput);
                }
            }
        }
    }

    /**
     * Invokes {@link #backwardsAt_border} for every output location in
     * {@code [row0,row1) x [col0,col1)}.
     *
     * <p>The row/column passed on is the output coordinate multiplied by the period, with no
     * padding offset subtracted.
     *
     * @param batchIndex batch element being processed
     * @param channel channel being processed
     * @param row0 first output row, inclusive
     * @param col0 first output column, inclusive
     * @param row1 last output row, exclusive
     * @param col1 last output column, exclusive
     */
    private void backwardsBorder(int batchIndex, int channel, int row0, int col0, int row1, int col1) {
        for (int outRow = row0; outRow < row1; ++outRow) {
            int padRow = outRow * this.config.periodY;
            for (int outCol = col0; outCol < col1; ++outCol) {
                int padCol = outCol * this.config.periodX;
                this.backwardsAt_border(this.padding, batchIndex, channel, padRow, padCol, outRow, outCol);
            }
        }
    }

    /**
     * Processes one output location that lies in the inner region.
     *
     * @param input the tensor passed to {@link #backwardsChannel}
     * @param batchIndex batch element being processed
     * @param channel channel being processed
     * @param inRow {@code outRow * periodY - paddingRow0}
     * @param inCol {@code outCol * periodX - paddingCol0}
     * @param outRow output row
     * @param outCol output column
     */
    protected abstract void backwardsAt_inner(T input, int batchIndex, int channel,
                                              int inRow, int inCol, int outRow, int outCol);

    /**
     * Processes one output location that lies outside the inner region.
     *
     * @param padded this function's padding object
     * @param batchIndex batch element being processed
     * @param channel channel being processed
     * @param padRow {@code outRow * periodY}
     * @param padCol {@code outCol * periodX}
     * @param outRow output row
     * @param outCol output column
     */
    protected abstract void backwardsAt_border(P padded, int batchIndex, int channel,
                                               int padRow, int padCol, int outRow, int outCol);

    /** Sets {@link #learningMode} to {@code true}. */
    public void learning() {
        this.learningMode = true;
    }

    /** Sets {@link #learningMode} to {@code false}. */
    public void evaluating() {
        this.learningMode = false;
    }

    /**
     * @return the current value of {@link #learningMode}
     */
    public boolean isLearning() {
        return this.learningMode;
    }

    /**
     * @return the tensor type reported by the padding object
     */
    public Class<T> getTensorType() {
        return this.padding.getTensorType();
    }
}

