/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.nn.api.Trainable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;

/**
 * A block of parameters whose gradients are transformed together by one {@link GradientUpdater}.
 *
 * <p>The updater is created lazily (see {@link #init()}) from the updater configuration of the
 * <em>first</em> {@link ParamState} in the block and is then applied to the whole block's gradient
 * view. Presumably every variable in a block shares an equivalent updater configuration, since only
 * the first one is consulted. Regularization, by contrast, is looked up and applied per variable.
 */
public class UpdaterBlock {
    private int paramOffsetStart;
    private int paramOffsetEnd;
    private int updaterViewOffsetStart;
    private int updaterViewOffsetEnd;
    private List<ParamState> layersAndVariablesInBlock;
    private INDArray updaterView;
    private INDArray gradientView;
    private boolean updaterViewRequiresInitialization;
    private GradientUpdater gradientUpdater;

    /**
     * Creates a block covering the given parameter and updater-state ranges.
     *
     * @param paramOffsetStart          start offset of this block's parameters; used to slice
     *                                  full-network gradient arrays in {@link #updateExternalGradient}
     * @param paramOffsetEnd            end offset of this block's parameters
     * @param updaterViewOffsetStart    start offset of this block's updater state
     * @param updaterViewOffsetEnd      end offset of this block's updater state
     * @param layersAndVariablesInBlock the variables in this block; must be non-empty before
     *                                  {@link #init()} or an update method is called
     */
    public UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart, int updaterViewOffsetEnd, List<ParamState> layersAndVariablesInBlock) {
        this.paramOffsetStart = paramOffsetStart;
        this.paramOffsetEnd = paramOffsetEnd;
        this.updaterViewOffsetStart = updaterViewOffsetStart;
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    /**
     * Instantiates this block's {@link GradientUpdater} if it has not been created yet.
     *
     * <p>The updater configuration comes from the first variable's layer and is bound to the
     * current updater view. An existing updater is never rebuilt, so
     * {@link #setUpdaterView(INDArray)} should be called before the first initialization.
     *
     * @throws IndexOutOfBoundsException if the block contains no variables
     */
    public void init() {
        if (this.gradientUpdater == null) {
            ParamState first = this.layersAndVariablesInBlock.get(0);
            this.gradientUpdater = first.getLayer().getConfig()
                    .getUpdaterByParam(first.getParamName())
                    .instantiate(this.updaterView, this.updaterViewRequiresInitialization);
        }
    }

    /**
     * Returns whether the first variable in the block is flagged as a pretrain parameter by its
     * layer configuration.
     *
     * @throws IndexOutOfBoundsException if the block contains no variables
     */
    public boolean isPretrainUpdaterBlock() {
        ParamState first = this.layersAndVariablesInBlock.get(0);
        return first.getLayer().getConfig().isPretrainParam(first.getParamName());
    }

    /**
     * Returns {@code true} when this is a pretrain block and the caller is not a layer updater,
     * i.e. the block should be skipped.
     *
     * @param isLayerUpdater whether the caller is a per-layer updater
     */
    public boolean skipDueToPretrainConfig(boolean isLayerUpdater) {
        return this.isPretrainUpdaterBlock() && !isLayerUpdater;
    }

    /** Returns this block's updater, lazily creating it via {@link #init()} on first access. */
    public GradientUpdater getGradientUpdater() {
        if (this.gradientUpdater == null) {
            this.init();
        }
        return this.gradientUpdater;
    }

    /**
     * Applies pre-updater regularization, the updater, and post-updater regularization to this
     * block's own gradient view.
     */
    public void update(int iteration, int epoch) {
        this.update(iteration, epoch, false, this.gradientView, null);
    }

    /**
     * Same as {@link #update(int, int)}, but operates on slices of externally supplied full-network
     * gradient and parameter arrays instead of the views stored in this block.
     * NOTE(review): the slicing assumes both arrays have the network's parameters along row 0 — confirm.
     *
     * @param fullNetworkGradientView gradient array for the whole network; modified in place
     * @param fullNetworkParamsArray  parameter array for the whole network, used by regularization
     */
    public void updateExternalGradient(int iteration, int epoch, INDArray fullNetworkGradientView, INDArray fullNetworkParamsArray) {
        this.update(iteration, epoch, true, fullNetworkGradientView, fullNetworkParamsArray);
    }

    private void update(int iteration, int epoch, boolean externalGradient, INDArray fullNetworkGradientView, INDArray fullNetworkParamsArray) {
        if (this.gradientUpdater == null) {
            this.init();
        }
        // Nothing to update for a parameterless layer. Only the first variable's layer is checked.
        Trainable firstLayer = this.layersAndVariablesInBlock.get(0).getLayer();
        if (firstLayer.numParams() == 0L) {
            return;
        }
        // Slice only after the early return above; an inclusive [0, 0] interval (not a point index)
        // is used for the row, presumably to keep the slice two-dimensional.
        INDArray blockGradViewArray = externalGradient
                ? fullNetworkGradientView.get(new INDArrayIndex[]{
                        NDArrayIndex.interval(0L, 0L, true),
                        NDArrayIndex.interval(this.paramOffsetStart, this.paramOffsetEnd)})
                : this.gradientView;
        this.applyRegularizationAllVariables(Regularization.ApplyStep.BEFORE_UPDATER, iteration, epoch, externalGradient, fullNetworkGradientView, fullNetworkParamsArray);
        this.gradientUpdater.applyUpdater(blockGradViewArray, iteration, epoch);
        this.applyRegularizationAllVariables(Regularization.ApplyStep.POST_UPDATER, iteration, epoch, externalGradient, fullNetworkGradientView, fullNetworkParamsArray);
    }

    /**
     * Applies every regularization registered for {@code applyStep} to each variable in the block.
     *
     * <p>The learning rate passed to regularization is the updater's rate for
     * ({@code iteration}, {@code epoch}) when the updater config has one, otherwise {@code 1.0}.
     */
    protected void applyRegularizationAllVariables(Regularization.ApplyStep applyStep, int iteration, int epoch, boolean externalGradient, INDArray fullNetworkGradientView, INDArray fullNetworkParamsArray) {
        if (this.layersAndVariablesInBlock.isEmpty()) {
            return;
        }
        // The block has a single updater, so the learning rate is loop-invariant; getGradientUpdater()
        // also guards against this protected method being called before init().
        GradientUpdater updater = this.getGradientUpdater();
        double lr = updater.getConfig().hasLearningRate() ? updater.getConfig().getLearningRate(iteration, epoch) : 1.0;
        for (ParamState p : this.layersAndVariablesInBlock) {
            INDArray paramView;
            INDArray gradView;
            if (externalGradient) {
                paramView = fullNetworkParamsArray.get(new INDArrayIndex[]{
                        NDArrayIndex.point(0L), NDArrayIndex.interval(p.getParamOffsetStart(), p.getParamOffsetEnd())});
                gradView = fullNetworkGradientView.get(new INDArrayIndex[]{
                        NDArrayIndex.point(0L), NDArrayIndex.interval(p.getParamOffsetStart(), p.getParamOffsetEnd())});
            } else {
                paramView = p.getParamView();
                gradView = p.getGradView();
            }
            this.applyRegularization(applyStep, p.getLayer(), p.getParamName(), gradView, paramView, iteration, epoch, lr);
        }
    }

    /**
     * Applies the regularizations configured for {@code paramName} on {@code layer} whose apply step
     * equals {@code step}. Does nothing if the layer reports no regularizations (null or empty).
     */
    protected void applyRegularization(Regularization.ApplyStep step, Trainable layer, String paramName, INDArray gradientView, INDArray paramsView, int iter, int epoch, double lr) {
        List<Regularization> regularizations = layer.getConfig().getRegularizationByParam(paramName);
        if (regularizations == null) {
            return;
        }
        for (Regularization r : regularizations) {
            if (r.applyStep() == step) {
                r.apply(paramsView, gradientView, lr, iter, epoch);
            }
        }
    }

    public int getParamOffsetStart() {
        return this.paramOffsetStart;
    }

    public int getParamOffsetEnd() {
        return this.paramOffsetEnd;
    }

    public int getUpdaterViewOffsetStart() {
        return this.updaterViewOffsetStart;
    }

    public int getUpdaterViewOffsetEnd() {
        return this.updaterViewOffsetEnd;
    }

    /** Returns the live (not copied) list of variables in this block. */
    public List<ParamState> getLayersAndVariablesInBlock() {
        return this.layersAndVariablesInBlock;
    }

    public INDArray getUpdaterView() {
        return this.updaterView;
    }

    public INDArray getGradientView() {
        return this.gradientView;
    }

    public boolean isUpdaterViewRequiresInitialization() {
        return this.updaterViewRequiresInitialization;
    }

    public void setParamOffsetStart(int paramOffsetStart) {
        this.paramOffsetStart = paramOffsetStart;
    }

    public void setParamOffsetEnd(int paramOffsetEnd) {
        this.paramOffsetEnd = paramOffsetEnd;
    }

    public void setUpdaterViewOffsetStart(int updaterViewOffsetStart) {
        this.updaterViewOffsetStart = updaterViewOffsetStart;
    }

    public void setUpdaterViewOffsetEnd(int updaterViewOffsetEnd) {
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
    }

    public void setLayersAndVariablesInBlock(List<ParamState> layersAndVariablesInBlock) {
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    /**
     * Sets the updater state view. An updater that has already been created keeps the view it was
     * instantiated with; only a later {@link #init()} on a fresh block picks up this value.
     */
    public void setUpdaterView(INDArray updaterView) {
        this.updaterView = updaterView;
    }

    public void setGradientView(INDArray gradientView) {
        this.gradientView = gradientView;
    }

    public void setUpdaterViewRequiresInitialization(boolean updaterViewRequiresInitialization) {
        this.updaterViewRequiresInitialization = updaterViewRequiresInitialization;
    }

    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        this.gradientUpdater = gradientUpdater;
    }

    /**
     * Compares all fields. The updater field is read directly rather than via
     * {@link #getGradientUpdater()}, so equality checks never instantiate an updater as a side effect.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof UpdaterBlock)) {
            return false;
        }
        UpdaterBlock other = (UpdaterBlock) o;
        return other.canEqual(this)
                && this.paramOffsetStart == other.paramOffsetStart
                && this.paramOffsetEnd == other.paramOffsetEnd
                && this.updaterViewOffsetStart == other.updaterViewOffsetStart
                && this.updaterViewOffsetEnd == other.updaterViewOffsetEnd
                && this.updaterViewRequiresInitialization == other.updaterViewRequiresInitialization
                && Objects.equals(this.layersAndVariablesInBlock, other.layersAndVariablesInBlock)
                && Objects.equals(this.updaterView, other.updaterView)
                && Objects.equals(this.gradientView, other.gradientView)
                && Objects.equals(this.gradientUpdater, other.gradientUpdater);
    }

    protected boolean canEqual(Object other) {
        return other instanceof UpdaterBlock;
    }

    /** Hash over the same fields as {@link #equals(Object)}; never triggers updater creation. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + this.paramOffsetStart;
        result = result * 59 + this.paramOffsetEnd;
        result = result * 59 + this.updaterViewOffsetStart;
        result = result * 59 + this.updaterViewOffsetEnd;
        result = result * 59 + (this.updaterViewRequiresInitialization ? 79 : 97);
        result = result * 59 + hashOrDefault(this.layersAndVariablesInBlock);
        result = result * 59 + hashOrDefault(this.updaterView);
        result = result * 59 + hashOrDefault(this.gradientView);
        result = result * 59 + hashOrDefault(this.gradientUpdater);
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for {@code null} (the constant used for null fields). */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    /** Describes all fields; prints {@code null} for an updater that has not been created yet. */
    @Override
    public String toString() {
        return "UpdaterBlock(paramOffsetStart=" + this.paramOffsetStart + ", paramOffsetEnd=" + this.paramOffsetEnd + ", updaterViewOffsetStart=" + this.updaterViewOffsetStart + ", updaterViewOffsetEnd=" + this.updaterViewOffsetEnd + ", layersAndVariablesInBlock=" + this.layersAndVariablesInBlock + ", updaterView=" + this.updaterView + ", gradientView=" + this.gradientView + ", updaterViewRequiresInitialization=" + this.updaterViewRequiresInitialization + ", gradientUpdater=" + this.gradientUpdater + ")";
    }

    /**
     * One variable in an {@link UpdaterBlock}: its owning layer, parameter name, offsets, and views
     * of its parameter and gradient arrays. All fields are final, though the arrays they reference
     * are not copied.
     */
    public static class ParamState {
        private final Trainable layer;
        private final String paramName;
        private final int paramOffsetStart;
        private final int paramOffsetEnd;
        private final INDArray paramView;
        private final INDArray gradView;

        public ParamState(Trainable layer, String paramName, int paramOffsetStart, int paramOffsetEnd, INDArray paramView, INDArray gradView) {
            this.layer = layer;
            this.paramName = paramName;
            this.paramOffsetStart = paramOffsetStart;
            this.paramOffsetEnd = paramOffsetEnd;
            this.paramView = paramView;
            this.gradView = gradView;
        }

        public Trainable getLayer() {
            return this.layer;
        }

        public String getParamName() {
            return this.paramName;
        }

        public int getParamOffsetStart() {
            return this.paramOffsetStart;
        }

        public int getParamOffsetEnd() {
            return this.paramOffsetEnd;
        }

        public INDArray getParamView() {
            return this.paramView;
        }

        public INDArray getGradView() {
            return this.gradView;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ParamState)) {
                return false;
            }
            ParamState other = (ParamState) o;
            return other.canEqual(this)
                    && this.paramOffsetStart == other.paramOffsetStart
                    && this.paramOffsetEnd == other.paramOffsetEnd
                    && Objects.equals(this.layer, other.layer)
                    && Objects.equals(this.paramName, other.paramName)
                    && Objects.equals(this.paramView, other.paramView)
                    && Objects.equals(this.gradView, other.gradView);
        }

        protected boolean canEqual(Object other) {
            return other instanceof ParamState;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + this.paramOffsetStart;
            result = result * 59 + this.paramOffsetEnd;
            result = result * 59 + hashOrDefault(this.layer);
            result = result * 59 + hashOrDefault(this.paramName);
            result = result * 59 + hashOrDefault(this.paramView);
            result = result * 59 + hashOrDefault(this.gradView);
            return result;
        }

        @Override
        public String toString() {
            return "UpdaterBlock.ParamState(layer=" + this.layer + ", paramName=" + this.paramName + ", paramOffsetStart=" + this.paramOffsetStart + ", paramOffsetEnd=" + this.paramOffsetEnd + ", paramView=" + this.paramView + ", gradView=" + this.gradView + ")";
        }
    }
}

