/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Session that runs a {@link SameDiff} graph over {@link DataType}s instead of arrays: the value
 * associated with each variable is its datatype, and each op's output datatypes are obtained from
 * {@link DifferentialFunction#calculateOutputDataTypes(List)} given the op's input datatypes.
 */
public class DataTypesSession extends AbstractSession<DataType, DataTypeCalc> {
    /**
     * If true, {@link #getOutputs} writes each calculated output datatype back to the corresponding
     * output {@link SDVariable} of the op.
     */
    protected boolean dynamicUpdate;

    /**
     * @param sameDiff      the graph whose datatypes are to be calculated
     * @param dynamicUpdate if true, update op output variables' datatypes as they are calculated
     */
    public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) {
        super(sameDiff);
        this.dynamicUpdate = dynamicUpdate;
    }

    /**
     * Returns the declared datatype of the named constant or variable.
     *
     * @param variableName name of the variable to look up
     * @return the variable's datatype (never null)
     * Fails via {@code Preconditions.checkNotNull} if no variable with this name exists, or if the
     * variable has no datatype.
     */
    @Override
    public DataType getConstantOrVariable(String variableName) {
        SDVariable variable = this.sameDiff.getVariable(variableName);
        // Report an unknown name explicitly instead of an unexplained NPE on the next line.
        Preconditions.checkNotNull(variable, "No variable with name %s exists", variableName);
        DataType dt = variable.dataType();
        Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName);
        return dt;
    }

    /**
     * Collects the input datatypes for the op with id {@code opName}.
     * <p>
     * For each argument of the op, the argument variable's own datatype is used when non-null.
     * Otherwise the datatype previously computed by this session ({@code nodeOutputs}) is looked
     * up for every {@link AbstractSession.VarId} in {@code inputs} with a matching variable name.
     * Each match contributes one entry. An argument with no datatype and no matching VarId
     * contributes nothing, so the returned list can be shorter than the op's argument list.
     * NOTE(review): presumably tolerated for ops whose inputs are not all available at once
     * (e.g. control-flow merges); confirm before making this strict.
     *
     * @param opName    id of the op in the graph
     * @param frameIter unused here
     * @param inputs    resolved inputs of the op, used for args whose datatype is not yet known
     * @return the op paired with the collected input datatypes
     */
    @Override
    public DataTypeCalc getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
        DifferentialFunction df = this.sameDiff.getOpById(opName);
        List<DataType> inputDataTypes = new ArrayList<>();
        for (SDVariable v : df.args()) {
            DataType dt = v.dataType();
            if (dt != null) {
                inputDataTypes.add(dt);
                continue;
            }
            // Datatype not declared on the variable: use the value this session calculated for it.
            String varName = v.getVarName();
            for (AbstractSession.VarId vid : inputs) {
                if (vid.getVariable().equals(varName)) {
                    DataType computed = (DataType) this.nodeOutputs.get(vid);
                    Preconditions.checkNotNull(computed, "No datatype for %s", vid);
                    inputDataTypes.add(computed);
                }
            }
        }
        return new DataTypeCalc(df, inputDataTypes);
    }

    /**
     * Calculates an op's output datatypes from the input datatypes collected by
     * {@link #getAndParameterizeOp}.
     * <p>
     * When {@link #dynamicUpdate} is set, every output variable of the op whose current datatype
     * differs from the calculated one is updated via {@link SDVariable#setDataType(DataType)}.
     *
     * @param op the op and its input datatypes
     * @return the calculated output datatypes, in the order returned by
     *         {@link DifferentialFunction#calculateOutputDataTypes(List)}
     * @throws IllegalStateException if {@code dynamicUpdate} is set and fewer datatypes were
     *                               calculated than the op has output variables
     */
    @Override
    public DataType[] getOutputs(DataTypeCalc op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
        DifferentialFunction fn = op.getFn();
        List<DataType> outTypes = fn.calculateOutputDataTypes(op.getInputTypes());
        if (this.dynamicUpdate) {
            SDVariable[] fnOutputs = fn.outputVariables();
            // Without this check a mismatch surfaces as an opaque IndexOutOfBoundsException below.
            if (outTypes.size() < fnOutputs.length) {
                throw new IllegalStateException("Op " + fn + " calculated " + outTypes.size()
                        + " output datatype(s) but has " + fnOutputs.length
                        + " output variable(s); calculated datatypes: " + outTypes);
            }
            for (int i = 0; i < fnOutputs.length; i++) {
                SDVariable v = fnOutputs[i];
                DataType d = outTypes.get(i);
                if (v.dataType() != d) {
                    v.setDataType(d);
                }
            }
        }
        return outTypes.toArray(new DataType[0]);
    }

    /**
     * The "parameterized op" type of {@link DataTypesSession}: an op together with the datatypes
     * of its inputs.
     */
    protected static class DataTypeCalc {
        /** The op whose output datatypes are to be calculated. */
        protected final DifferentialFunction fn;
        /** Datatypes of the op's inputs, as collected by the session. */
        protected final List<DataType> inputTypes;

        /**
         * @param fn         the op
         * @param inputTypes the op's input datatypes (stored as-is, not copied)
         */
        public DataTypeCalc(DifferentialFunction fn, List<DataType> inputTypes) {
            this.fn = fn;
            this.inputTypes = inputTypes;
        }

        /** @return the op */
        public DifferentialFunction getFn() {
            return this.fn;
        }

        /** @return the op's input datatypes */
        public List<DataType> getInputTypes() {
            return this.inputTypes;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof DataTypeCalc)) {
                return false;
            }
            DataTypeCalc other = (DataTypeCalc) o;
            // canEqual keeps equals symmetric if a subclass narrows equality.
            return other.canEqual(this)
                    && Objects.equals(this.getFn(), other.getFn())
                    && Objects.equals(this.getInputTypes(), other.getInputTypes());
        }

        /**
         * @param other candidate object
         * @return whether {@code other} may be considered equal to an instance of this class
         */
        protected boolean canEqual(Object other) {
            return other instanceof DataTypeCalc;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.getFn(), this.getInputTypes());
        }

        @Override
        public String toString() {
            return "DataTypesSession.DataTypeCalc(fn=" + this.getFn() + ", inputTypes=" + this.getInputTypes() + ")";
        }
    }
}

