/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Map;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.CapsuleUtils;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Capsule layer configuration implemented on top of {@link SameDiffLayer}.
 *
 * <p>Only recurrent-type input ({@code InputType.Type.RNN}) is accepted. When the number of input
 * capsules or their dimensionality has not been configured, {@link #setNIn(InputType, boolean)}
 * takes them from the input type's size and time-series length respectively. The reported output
 * type is recurrent with size {@code capsules} and time-series length {@code capsuleDimensions}.
 *
 * <p>The forward pass built by {@code defineLayer} runs {@code routings} iterations in which a
 * softmax over the routing logits weights the prediction tensor, the weighted sum is squashed, and
 * (on every iteration but the last) the logits are updated from the product of the predictions and
 * the squashed result.
 */
public class CapsuleLayer
extends SameDiffLayer {
    /** Key of the weight tensor in the parameter table. */
    private static final String WEIGHT_PARAM = "weight";
    /** Key of the optional bias tensor in the parameter table. */
    private static final String BIAS_PARAM = "bias";
    /** Whether a bias parameter is defined and added to the weighted sum on each routing iteration. */
    private boolean hasBias = false;
    /** Number of input capsules; a value {@code <= 0} means "not set" and is filled in by {@code setNIn}. */
    private long inputCapsules = 0L;
    /** Dimensionality of each input capsule; a value {@code <= 0} means "not set" and is filled in by {@code setNIn}. */
    private long inputCapsuleDimensions = 0L;
    /** Number of output capsules. */
    private int capsules;
    /** Dimensionality of each output capsule. */
    private int capsuleDimensions;
    /** Number of routing iterations performed in the forward pass. */
    private int routings = 3;

    /**
     * No-argument constructor. All fields keep their declared defaults; no validation is performed.
     */
    public CapsuleLayer() {
    }

    /**
     * Creates the layer configuration from a builder and validates it.
     *
     * @param builder source of all configuration values
     * @throws IllegalArgumentException if {@code capsules}, {@code capsuleDimensions} or
     *     {@code routings} is not strictly positive, or if {@code inputCapsules} or
     *     {@code inputCapsuleDimensions} is negative
     */
    public CapsuleLayer(Builder builder) {
        super(builder);
        this.hasBias = builder.hasBias;
        this.inputCapsules = builder.inputCapsules;
        this.inputCapsuleDimensions = builder.inputCapsuleDimensions;
        this.capsules = builder.capsules;
        this.capsuleDimensions = builder.capsuleDimensions;
        this.routings = builder.routings;

        boolean outputConfigInvalid = this.capsules <= 0 || this.capsuleDimensions <= 0 || this.routings <= 0;
        if (outputConfigInvalid) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): capsules, capsuleDimensions, and routings must be > 0.  Got: " + this.capsules + ", " + this.capsuleDimensions + ", " + this.routings);
        }
        // Zero is allowed here: it marks the input shape as "to be inferred" in setNIn.
        boolean inputConfigInvalid = this.inputCapsules < 0L || this.inputCapsuleDimensions < 0L;
        if (inputConfigInvalid) {
            throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): inputCapsules and inputCapsuleDimensions must be >= 0 if set.  Got: " + this.inputCapsules + ", " + this.inputCapsuleDimensions);
        }
    }

    /**
     * Infers the input capsule shape from the input type when it has not been fully configured.
     *
     * <p>If either {@code inputCapsules} or {@code inputCapsuleDimensions} is {@code <= 0}, both
     * are overwritten: {@code inputCapsules} with the input size and {@code inputCapsuleDimensions}
     * with the time-series length. The {@code override} flag is not consulted.
     *
     * @param inputType the input type; must be non-null and of type RNN
     * @param override unused by this implementation
     * @throws IllegalStateException if {@code inputType} is null or not an RNN input type
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Capsule layer (layer name = \"" + this.layerName + "\"): expect RNN input.  Got: " + inputType);
        }
        boolean inputShapeConfigured = this.inputCapsules > 0L && this.inputCapsuleDimensions > 0L;
        if (inputShapeConfigured) {
            return;
        }
        InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent)inputType;
        this.inputCapsules = recurrent.getSize();
        this.inputCapsuleDimensions = recurrent.getTimeSeriesLength();
    }

    /**
     * Builds the forward pass of the layer.
     *
     * @param SD the SameDiff instance the operations are added to
     * @param input the layer input; the shape comments below assume
     *     {@code [mb, inputCapsules, inputCapsuleDimensions]} (TODO confirm against callers)
     * @param paramTable parameter variables keyed by parameter name
     * @param mask not used by this layer
     * @return the squashed output of the final routing iteration with the singleton dimensions 1
     *     and 3 squeezed out
     * @throws IllegalStateException if {@code routings} is not strictly positive
     */
    @Override
    public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
        // The builder-based constructor validates this, but setRoutings does not; fail fast
        // rather than hand a null output variable back to the caller.
        if (this.routings <= 0) {
            throw new IllegalStateException("Invalid configuration for Capsule Layer (layer name = \"" + this.layerName + "\"): routings must be > 0.  Got: " + this.routings);
        }

        final int flatOutputSize = this.capsules * this.capsuleDimensions;

        // Insert singleton dimensions at positions 2 and 4, then repeat dimension 2 so the input
        // lines up with the weight shape declared in defineParameters:
        // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1].
        SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4);
        SDVariable tiled = SD.tile(expanded, new int[]{1, 1, flatOutputSize, 1, 1});
        SDVariable weights = paramTable.get(WEIGHT_PARAM);

        // Predictions: reduce over dimension 3, then reshape to
        // [mb, inputCapsules, capsules, capsuleDimensions, 1].
        SDVariable uHat = weights.times(tiled).sum(true, new int[]{3}).reshape(new long[]{-1L, this.inputCapsules, this.capsules, this.capsuleDimensions, 1L});

        // Routing logits start at zero: [mb, inputCapsules, capsules, 1, 1].
        SDVariable b = SD.zerosLike(uHat).get(new SDIndex[]{SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval((Integer)0, (Integer)1), SDIndex.interval((Integer)0, (Integer)1)});

        final int lastRouting = this.routings - 1;
        SDVariable v = null;
        for (int i = 0; i <= lastRouting; ++i) {
            // Presumably a softmax along dimension 2 of a rank-5 tensor — see CapsuleUtils.softmax.
            SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5);

            // Weighted sum over the input capsules: [mb, 1, capsules, capsuleDimensions, 1].
            SDVariable s = c.times(uHat).sum(true, new int[]{1});
            if (this.hasBias) {
                s = s.plus(paramTable.get(BIAS_PARAM));
            }

            // Presumably squashes along dimension 3 — see CapsuleUtils.squash.
            v = CapsuleUtils.squash(SD, s, 3);

            // The logits only need updating when another iteration follows.
            if (i < lastRouting) {
                // Narrowing cast: inputCapsules is assumed to fit in an int here.
                SDVariable vTiled = SD.tile(v, new int[]{1, (int)this.inputCapsules, 1, 1, 1});
                b = b.plus(uHat.times(vTiled).sum(true, new int[]{3}));
            }
        }
        return SD.squeeze(SD.squeeze(v, 1), 3);
    }

    /**
     * Declares the parameters of the layer: a weight tensor of shape
     * {@code [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]} and, when
     * {@code hasBias} is set, a bias tensor of shape {@code [1, 1, capsules, capsuleDimensions, 1]}.
     *
     * @param params the parameter definitions to populate; cleared first
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        final int flatOutputSize = this.capsules * this.capsuleDimensions;
        params.addWeightParam(WEIGHT_PARAM, 1L, this.inputCapsules, flatOutputSize, this.inputCapsuleDimensions, 1L);
        if (this.hasBias) {
            params.addBiasParam(BIAS_PARAM, 1L, 1L, this.capsules, this.capsuleDimensions, 1L);
        }
    }

    /**
     * Initializes the parameter arrays: the bias is set to zero and the weights are initialized
     * through {@link WeightInitUtil} using this layer's {@code weightInit} scheme. Entries with any
     * other key are left untouched.
     *
     * @param params parameter arrays keyed by parameter name
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // The workspace handle is never read; it is held only so it is closed when the block exits.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                String key = entry.getKey();
                if (BIAS_PARAM.equals(key)) {
                    entry.getValue().assign((Number)0);
                } else if (WEIGHT_PARAM.equals(key)) {
                    double fanIn = (double)(this.inputCapsules * this.inputCapsuleDimensions);
                    double fanOut = (double)(this.capsules * this.capsuleDimensions);
                    long[] weightShape = new long[]{1L, this.inputCapsules, this.capsules * this.capsuleDimensions, this.inputCapsuleDimensions, 1L};
                    WeightInitUtil.initWeights(fanIn, fanOut, weightShape, this.weightInit, null, 'c', entry.getValue());
                }
            }
        }
    }

    /**
     * Returns a recurrent input type with size {@code capsules} and time-series length
     * {@code capsuleDimensions}. Neither argument is consulted.
     *
     * @param layerIndex unused
     * @param inputType unused
     * @return the output type of this layer
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        return InputType.recurrent(this.capsules, this.capsuleDimensions);
    }

    /** @return whether the layer uses a bias parameter */
    public boolean isHasBias() {
        return this.hasBias;
    }

    /** @return the number of input capsules ({@code <= 0} if not yet set) */
    public long getInputCapsules() {
        return this.inputCapsules;
    }

    /** @return the dimensionality of each input capsule ({@code <= 0} if not yet set) */
    public long getInputCapsuleDimensions() {
        return this.inputCapsuleDimensions;
    }

    /** @return the number of output capsules */
    public int getCapsules() {
        return this.capsules;
    }

    /** @return the dimensionality of each output capsule */
    public int getCapsuleDimensions() {
        return this.capsuleDimensions;
    }

    /** @return the number of routing iterations */
    public int getRoutings() {
        return this.routings;
    }

    /** @param hasBias whether the layer uses a bias parameter */
    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    /** @param inputCapsules the number of input capsules; not validated */
    public void setInputCapsules(long inputCapsules) {
        this.inputCapsules = inputCapsules;
    }

    /** @param inputCapsuleDimensions the dimensionality of each input capsule; not validated */
    public void setInputCapsuleDimensions(long inputCapsuleDimensions) {
        this.inputCapsuleDimensions = inputCapsuleDimensions;
    }

    /** @param capsules the number of output capsules; not validated */
    public void setCapsules(int capsules) {
        this.capsules = capsules;
    }

    /** @param capsuleDimensions the dimensionality of each output capsule; not validated */
    public void setCapsuleDimensions(int capsuleDimensions) {
        this.capsuleDimensions = capsuleDimensions;
    }

    /** @param routings the number of routing iterations; not validated */
    public void setRoutings(int routings) {
        this.routings = routings;
    }

    /** @return a string listing all capsule-specific configuration values */
    @Override
    public String toString() {
        return "CapsuleLayer(hasBias=" + this.isHasBias() + ", inputCapsules=" + this.getInputCapsules() + ", inputCapsuleDimensions=" + this.getInputCapsuleDimensions() + ", capsules=" + this.getCapsules() + ", capsuleDimensions=" + this.getCapsuleDimensions() + ", routings=" + this.getRoutings() + ")";
    }

    /**
     * Two capsule layers are equal when the superclass considers them equal and all
     * capsule-specific configuration values match.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof CapsuleLayer)) {
            return false;
        }
        CapsuleLayer other = (CapsuleLayer)o;
        // canEqual is consulted before super.equals, as in the original ordering.
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        return this.isHasBias() == other.isHasBias()
                && this.getInputCapsules() == other.getInputCapsules()
                && this.getInputCapsuleDimensions() == other.getInputCapsuleDimensions()
                && this.getCapsules() == other.getCapsules()
                && this.getCapsuleDimensions() == other.getCapsuleDimensions()
                && this.getRoutings() == other.getRoutings();
    }

    /** Lets subclasses opt out of equality with plain {@code CapsuleLayer} instances. */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof CapsuleLayer;
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + (this.isHasBias() ? 79 : 97);
        long inCaps = this.getInputCapsules();
        result = result * prime + (int)((inCaps >>> 32) ^ inCaps);
        long inCapDims = this.getInputCapsuleDimensions();
        result = result * prime + (int)((inCapDims >>> 32) ^ inCapDims);
        result = result * prime + this.getCapsules();
        result = result * prime + this.getCapsuleDimensions();
        result = result * prime + this.getRoutings();
        return result;
    }

    /**
     * Builder for {@link CapsuleLayer}. The values are validated when the layer is constructed,
     * not when they are set on the builder.
     */
    public static class Builder
    extends SameDiffLayer.Builder<Builder> {
        /** Number of output capsules. */
        private int capsules;
        /** Dimensionality of each output capsule. */
        private int capsuleDimensions;
        /** Number of routing iterations. */
        private int routings = 3;
        /** Whether the layer uses a bias parameter. */
        private boolean hasBias = false;
        /** Number of input capsules; 0 leaves it to be inferred from the input type. */
        private int inputCapsules = 0;
        /** Dimensionality of each input capsule; 0 leaves it to be inferred from the input type. */
        private int inputCapsuleDimensions = 0;

        /**
         * Creates a builder with 3 routing iterations.
         *
         * @param capsules number of output capsules
         * @param capsuleDimensions dimensionality of each output capsule
         */
        public Builder(int capsules, int capsuleDimensions) {
            this(capsules, capsuleDimensions, 3);
        }

        /**
         * @param capsules number of output capsules
         * @param capsuleDimensions dimensionality of each output capsule
         * @param routings number of routing iterations
         */
        public Builder(int capsules, int capsuleDimensions, int routings) {
            this.setCapsules(capsules);
            this.setCapsuleDimensions(capsuleDimensions);
            this.setRoutings(routings);
        }

        /**
         * Builds the layer configuration.
         *
         * @return a new {@link CapsuleLayer}
         * @throws IllegalArgumentException if the configured values fail the layer's validation
         */
        @Override
        @SuppressWarnings("unchecked") // The result is always a CapsuleLayer; the caller picks E.
        public <E extends Layer> E build() {
            return (E)new CapsuleLayer(this);
        }

        /** @param capsules number of output capsules */
        public Builder capsules(int capsules) {
            this.setCapsules(capsules);
            return this;
        }

        /** @param capsuleDimensions dimensionality of each output capsule */
        public Builder capsuleDimensions(int capsuleDimensions) {
            this.setCapsuleDimensions(capsuleDimensions);
            return this;
        }

        /** @param routings number of routing iterations */
        public Builder routings(int routings) {
            this.setRoutings(routings);
            return this;
        }

        /** @param inputCapsules number of input capsules */
        public Builder inputCapsules(int inputCapsules) {
            this.setInputCapsules(inputCapsules);
            return this;
        }

        /** @param inputCapsuleDimensions dimensionality of each input capsule */
        public Builder inputCapsuleDimensions(int inputCapsuleDimensions) {
            this.setInputCapsuleDimensions(inputCapsuleDimensions);
            return this;
        }

        /**
         * Sets the input capsule count and dimensionality in one call.
         *
         * @param inputShape values passed through {@code ValidationUtils.validate2NonNegative};
         *     the first element of its result becomes {@code inputCapsules} and the second
         *     {@code inputCapsuleDimensions}
         */
        public Builder inputShape(int ... inputShape) {
            int[] validated = ValidationUtils.validate2NonNegative(inputShape, false, "inputShape");
            this.setInputCapsules(validated[0]);
            this.setInputCapsuleDimensions(validated[1]);
            return this;
        }

        /** @param hasBias whether the layer uses a bias parameter */
        public Builder hasBias(boolean hasBias) {
            this.setHasBias(hasBias);
            return this;
        }

        /** @return the configured number of output capsules */
        public int getCapsules() {
            return this.capsules;
        }

        /** @return the configured dimensionality of each output capsule */
        public int getCapsuleDimensions() {
            return this.capsuleDimensions;
        }

        /** @return the configured number of routing iterations */
        public int getRoutings() {
            return this.routings;
        }

        /** @return whether a bias parameter is configured */
        public boolean isHasBias() {
            return this.hasBias;
        }

        /** @return the configured number of input capsules (0 if not set) */
        public int getInputCapsules() {
            return this.inputCapsules;
        }

        /** @return the configured dimensionality of each input capsule (0 if not set) */
        public int getInputCapsuleDimensions() {
            return this.inputCapsuleDimensions;
        }

        /** @param capsules number of output capsules */
        public void setCapsules(int capsules) {
            this.capsules = capsules;
        }

        /** @param capsuleDimensions dimensionality of each output capsule */
        public void setCapsuleDimensions(int capsuleDimensions) {
            this.capsuleDimensions = capsuleDimensions;
        }

        /** @param routings number of routing iterations */
        public void setRoutings(int routings) {
            this.routings = routings;
        }

        /** @param hasBias whether the layer uses a bias parameter */
        public void setHasBias(boolean hasBias) {
            this.hasBias = hasBias;
        }

        /** @param inputCapsules number of input capsules */
        public void setInputCapsules(int inputCapsules) {
            this.inputCapsules = inputCapsules;
        }

        /** @param inputCapsuleDimensions dimensionality of each input capsule */
        public void setInputCapsuleDimensions(int inputCapsuleDimensions) {
            this.inputCapsuleDimensions = inputCapsuleDimensions;
        }
    }
}

