/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.serde;

import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.serde.LegacyOpMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
import org.nd4j.graph.IntPair;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static conversion helpers between SameDiff / ND4J objects and their FlatBuffers form.
 *
 * <p>Covers byte codes for data types, op types and byte order; reconstruction of a
 * {@link DifferentialFunction} from a {@link FlatNode}; and two-way mapping of function
 * properties to and from {@link FlatProperties} tables. The class is not instantiable.
 */
public class FlatBuffersMapper {
    // Shared zero-length arrays written for FlatProperties vectors that carry no data.
    private static final boolean[] EMPTY_BOOLEAN = new boolean[0];
    private static final int[] EMPTY_INT = new int[0];
    private static final long[] EMPTY_LONG = new long[0];
    private static final double[] EMPTY_DOUBLE = new double[0];

    private FlatBuffersMapper() {
    }

    /**
     * Encodes an ND4J data buffer type as its FlatBuffers byte code.
     *
     * @param type the data type to encode
     * @return 3 for HALF, 5 for FLOAT, 6 for DOUBLE, 9 for INT, 10 for LONG
     * @throws ND4JIllegalStateException if {@code type} has no FlatBuffers code
     */
    public static byte getDataTypeAsByte(DataBuffer.Type type) {
        switch (type) {
            case FLOAT:
                return 5;
            case DOUBLE:
                return 6;
            case HALF:
                return 3;
            case INT:
                return 9;
            case LONG:
                return 10;
            default:
                throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
        }
    }

    /**
     * Decodes a FlatBuffers data type byte code. This is the inverse of
     * {@link #getDataTypeAsByte(DataBuffer.Type)}, so every code that method emits is accepted here.
     *
     * @param val the byte code
     * @return the matching data type
     * @throws UnsupportedOperationException if the code is not recognised
     */
    public static DataBuffer.Type getDataTypeFromByte(byte val) {
        switch (val) {
            case 3:
                return DataBuffer.Type.HALF;
            case 5:
                return DataBuffer.Type.FLOAT;
            case 6:
                return DataBuffer.Type.DOUBLE;
            case 9:
                // INT/LONG are written by getDataTypeAsByte, so they must also be readable.
                return DataBuffer.Type.INT;
            case 10:
                return DataBuffer.Type.LONG;
            default:
                throw new UnsupportedOperationException("Unsupported DataType: [" + val + "]");
        }
    }

    /**
     * Resolves the numeric op identifier used in serialized graphs.
     *
     * <p>Control-flow op types map to fixed constants. Custom ops map to the hash of their
     * registered descriptor: the lower-cased name is looked up first, then the exact name, and
     * {@code 0} is returned if neither is registered. All other op types are resolved by name
     * through the op factory.
     *
     * @param name op name
     * @param type op type
     * @return the op number
     */
    public static long getOpNum(String name, Op.Type type) {
        if (type == Op.Type.LOOP) {
            return 0L;
        }
        if (type == Op.Type.RETURN) {
            return 40L;
        }
        if (type == Op.Type.IF) {
            return 30L;
        }
        if (type == Op.Type.CONDITIONAL) {
            return 10L;
        }
        if (type == Op.Type.MERGE) {
            return 60L;
        }
        if (type == Op.Type.LOOP_COND) {
            return 70L;
        }
        if (type == Op.Type.NEXT_ITERATION) {
            return 80L;
        }
        if (type == Op.Type.EXIT) {
            return 90L;
        }
        if (type == Op.Type.ENTER) {
            return 100L;
        }
        if (type == Op.Type.CUSTOM) {
            // Locale.ROOT: default-locale lower-casing (e.g. Turkish dotless i) would break lookups.
            CustomOpDescriptor descriptor =
                    Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase(Locale.ROOT));
            if (descriptor == null) {
                descriptor = Nd4j.getExecutioner().getCustomOperations().get(name);
            }
            if (descriptor == null) {
                return 0L;
            }
            return descriptor.getHash();
        }
        return Nd4j.getOpFactory().getOpNumByName(name);
    }

    /**
     * Decodes a FlatBuffers op type byte.
     *
     * <p>Code 7 decodes to SUMMARYSTATS, although {@link #getFlatOpType(Op.Type)} also writes
     * VARIANCE as 7. Code 119 decodes to META, although control-flow types are also written as 119.
     *
     * @param type the byte code
     * @return the op type
     * @throws UnsupportedOperationException if the code is not recognised
     */
    public static Op.Type getTypeFromByte(byte type) {
        switch (type) {
            case 0:
                return Op.Type.TRANSFORM;
            case 1:
                return Op.Type.REDUCE;
            case 2:
                return Op.Type.INDEXREDUCE;
            case 3:
                return Op.Type.SCALAR;
            case 4:
                return Op.Type.BROADCAST;
            case 5:
                return Op.Type.PAIRWISE;
            case 6:
                return Op.Type.REDUCE3;
            case 7:
                return Op.Type.SUMMARYSTATS;
            case 8:
                return Op.Type.SHAPE;
            case 10:
                return Op.Type.RANDOM;
            case 11:
                return Op.Type.CUSTOM;
            case 119:
                return Op.Type.META;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    /**
     * Encodes an op type as its FlatBuffers byte code.
     *
     * @param type the op type
     * @return the byte code; TRANSFORM and SPECIAL share 0, VARIANCE and SUMMARYSTATS share 7,
     *     and all control-flow types share 119
     * @throws UnsupportedOperationException if the type has no code
     */
    public static byte getFlatOpType(Op.Type type) {
        switch (type) {
            case SCALAR:
                return 3;
            case BROADCAST:
                return 4;
            case TRANSFORM:
            case SPECIAL:
                return 0;
            case REDUCE:
                return 1;
            case REDUCE3:
                return 6;
            case INDEXREDUCE:
                return 2;
            case RANDOM:
                return 10;
            case VARIANCE:
            case SUMMARYSTATS:
                return 7;
            case MERGE:
            case CONDITIONAL:
            case LOOP:
            case RETURN:
            case ENTER:
            case EXIT:
            case NEXT_ITERATION:
            case LOOP_COND:
            case IF:
                return 119;
            case CUSTOM:
                return 11;
            case SHAPE:
                return 8;
            case PAIRWISE:
                return 5;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    /**
     * Decodes a byte-order flag.
     *
     * @param val 0 for little-endian; any other value means big-endian
     * @return the byte order
     */
    public static ByteOrder getOrderFromByte(byte val) {
        return val == 0 ? ByteOrder.LITTLE_ENDIAN : ByteOrder.BIG_ENDIAN;
    }

    /**
     * Encodes this platform's native byte order.
     *
     * @return 1 if the native order is big-endian, otherwise 0
     */
    public static byte getOrderAsByte() {
        return ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN) ? (byte) 1 : (byte) 0;
    }

    /**
     * Reconstructs a {@link DifferentialFunction} from a serialized {@link FlatNode}.
     *
     * <p>Custom ops are resolved by hash and op name, and receive the node's integer and
     * floating-point arguments. Legacy ops are resolved from op type and op number, and receive
     * extra args, scalar and reduction dimensions as appropriate for their type. In both cases the
     * node's properties are applied via {@code setPropertiesForFunction}.
     *
     * @param fn the serialized node
     * @return a new function instance
     * @throws RuntimeException if the op class cannot be found or instantiated
     */
    public static DifferentialFunction fromFlatNode(FlatNode fn) {
        String name = fn.name();
        Op.Type opType = getTypeFromByte(fn.opType());
        long opNum = fn.opNum();

        double[] extraParams = new double[fn.extraParamsLength()];
        for (int i = 0; i < extraParams.length; i++) {
            extraParams[i] = fn.extraParams(i);
        }
        long[] extraInteger = new long[fn.extraIntegerLength()];
        for (int i = 0; i < extraInteger.length; i++) {
            extraInteger[i] = fn.extraInteger(i);
        }
        int[] dimensions = new int[fn.dimensionsLength()];
        for (int i = 0; i < dimensions.length; i++) {
            dimensions[i] = fn.dimensions(i);
        }
        float scalar = fn.scalar();
        FlatProperties[] flatProperties = new FlatProperties[fn.propertiesLength()];
        for (int i = 0; i < flatProperties.length; i++) {
            flatProperties[i] = fn.properties(i);
        }
        Map<String, Object> props = mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));

        if (opType == Op.Type.CUSTOM) {
            String opName = fn.opName();
            Class<?> c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);
            Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);
            DifferentialFunction customOp;
            try {
                customOp = (DifferentialFunction) c.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                // Keep the reflective failure as the cause so the root problem stays visible.
                throw new RuntimeException("Error creating differential function instance of type " + c, e);
            }
            customOp.setOwnName(name);
            ((CustomOp) customOp).addIArgument(extraInteger);
            ((CustomOp) customOp).addTArgument(extraParams);
            customOp.setPropertiesForFunction(props);
            return customOp;
        }

        Class<?> c = LegacyOpMapper.getLegacyOpClassForId(opType, (int) opNum);
        if (c == null) {
            throw new ND4JIllegalStateException(
                    "Could not find legacy op class for op type " + opType + " and op number " + opNum);
        }
        Op op;
        try {
            op = (Op) c.newInstance();
        } catch (IllegalAccessException | InstantiationException e) {
            throw new RuntimeException("Error creating differential function (Op) instance of type " + c, e);
        }
        if (extraParams.length > 0) {
            Object[] extraParamsObj = new Object[extraParams.length];
            for (int i = 0; i < extraParams.length; i++) {
                extraParamsObj[i] = extraParams[i];
            }
            op.setExtraArgs(extraParamsObj);
        }
        if (opType == Op.Type.SCALAR) {
            ((ScalarOp) op).setScalar(Float.valueOf(scalar));
        } else if (opType == Op.Type.REDUCE || opType == Op.Type.REDUCE3
                || opType == Op.Type.SUMMARYSTATS || opType == Op.Type.VARIANCE) {
            BaseAccumulation ba = (BaseAccumulation) op;
            ba.setDimensions(dimensions);
            ba.setNewFormat(true);
        } else if (opType == Op.Type.INDEXREDUCE) {
            BaseIndexAccumulation bia = (BaseIndexAccumulation) op;
            bia.setDimensions(dimensions);
            bia.setNewFormat(true);
        }
        DifferentialFunction df = (DifferentialFunction) op;
        df.setPropertiesForFunction(props);
        return df;
    }

    /**
     * Serializes function properties into {@link FlatProperties} tables inside {@code fbb}.
     *
     * <p>Supported values:
     * <ul>
     *   <li>{@code Boolean}, {@code Double}, {@code Integer}, {@code Long}, {@code String} and
     *       {@link INDArray} scalars, written without a shape;</li>
     *   <li>1-D primitive, {@code String[]} and {@code INDArray[]} arrays, written with a shape of
     *       {@code [length]};</li>
     *   <li>2-D and 3-D boolean/double/int/long arrays, written flattened with their shape.</li>
     * </ul>
     * A {@code null} value is written as a property with all vectors empty. Other value types not
     * listed above are also written as empty properties, except for the cases that throw.
     *
     * @param fbb     builder that receives the tables
     * @param fnProps property name to value
     * @return offsets of the created FlatProperties tables, in map iteration order
     * @throws UnsupportedOperationException for unsupported Number subtypes, primitive array types
     *     or multidimensional array types
     */
    public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map<String, Object> fnProps) {
        int[] outIdxs = new int[fnProps.size()];
        int count = 0;
        for (Map.Entry<String, Object> e : fnProps.entrySet()) {
            Object v = e.getValue();
            int iname = fbb.createString(e.getKey());
            int[] i = null;
            long[] l = null;
            double[] d = null;
            int[] aIdx = null;
            boolean[] b = null;
            int[] sIdx = null;
            int[] shape = null;
            if (v != null) {
                if (v instanceof Boolean) {
                    b = new boolean[]{(Boolean) v};
                } else if (v instanceof Number) {
                    if (v instanceof Double) {
                        d = new double[]{(Double) v};
                    } else if (v instanceof Integer) {
                        i = new int[]{(Integer) v};
                    } else if (v instanceof Long) {
                        l = new long[]{(Long) v};
                    } else {
                        throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
                    }
                } else if (v instanceof String) {
                    sIdx = new int[]{fbb.createString((String) v)};
                } else if (v instanceof INDArray) {
                    aIdx = new int[]{((INDArray) v).toFlatArray(fbb)};
                } else if (v.getClass().isArray()) {
                    if (v.getClass().getComponentType().isPrimitive()) {
                        if (v instanceof boolean[]) {
                            b = (boolean[]) v;
                            shape = new int[]{b.length};
                        } else if (v instanceof double[]) {
                            d = (double[]) v;
                            shape = new int[]{d.length};
                        } else if (v instanceof int[]) {
                            i = (int[]) v;
                            shape = new int[]{i.length};
                        } else if (v instanceof long[]) {
                            l = (long[]) v;
                            shape = new int[]{l.length};
                        } else {
                            throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
                        }
                    } else if (v instanceof String[]) {
                        String[] strArr = (String[]) v;
                        sIdx = new int[strArr.length];
                        for (int j = 0; j < strArr.length; j++) {
                            sIdx[j] = fbb.createString(strArr[j]);
                        }
                        shape = new int[]{strArr.length};
                    } else if (v instanceof INDArray[]) {
                        INDArray[] arrArr = (INDArray[]) v;
                        aIdx = new int[arrArr.length];
                        for (int j = 0; j < arrArr.length; j++) {
                            aIdx[j] = arrArr[j].toFlatArray(fbb);
                        }
                        // Without a shape the reader treats the property as a single INDArray
                        // and drops every element after the first.
                        shape = new int[]{arrArr.length};
                    } else if (v.getClass().getComponentType().isArray()) {
                        shape = ArrayUtil.arrayShape(v, true);
                        if (v instanceof boolean[][]) {
                            b = ArrayUtil.flatten((boolean[][]) v);
                        } else if (v instanceof boolean[][][]) {
                            b = ArrayUtil.flatten((boolean[][][]) v);
                        } else if (v instanceof double[][]) {
                            d = ArrayUtil.flatten((double[][]) v);
                        } else if (v instanceof double[][][]) {
                            d = ArrayUtil.flatten((double[][][]) v);
                        } else if (v instanceof int[][]) {
                            i = ArrayUtil.flatten((int[][]) v);
                        } else if (v instanceof int[][][]) {
                            i = ArrayUtil.flatten((int[][][]) v);
                        } else if (v instanceof long[][]) {
                            l = ArrayUtil.flatten((long[][]) v);
                        } else if (v instanceof long[][][]) {
                            l = ArrayUtil.flatten((long[][][]) v);
                        } else {
                            throw new UnsupportedOperationException("Unable to map multidimensional array property \"" + e.getKey() + "\" of type " + v.getClass());
                        }
                    }
                }
            }
            int idxD = FlatProperties.createDVector(fbb, d != null ? d : EMPTY_DOUBLE);
            int idxI = FlatProperties.createIVector(fbb, i != null ? i : EMPTY_INT);
            int idxL = FlatProperties.createLVector(fbb, l != null ? l : EMPTY_LONG);
            int idxA = FlatProperties.createAVector(fbb, aIdx != null ? aIdx : EMPTY_INT);
            int idxB = FlatProperties.createBVector(fbb, b != null ? b : EMPTY_BOOLEAN);
            int idxS = FlatProperties.createSVector(fbb, sIdx != null ? sIdx : EMPTY_INT);
            int idxShape = FlatProperties.createShapeVector(fbb, shape != null ? shape : EMPTY_INT);
            outIdxs[count++] = FlatProperties.createFlatProperties(fbb, iname, idxI, idxL, idxD, idxA, idxB, idxS, idxShape);
        }
        return outIdxs;
    }

    /**
     * Deserializes {@link FlatProperties} tables into a property-name to value map.
     *
     * <p>Properties with a non-empty shape vector are decoded as arrays (see
     * {@link #putShapedProperty}). Properties without one are decoded as single values (see
     * {@link #putScalarProperty}). A property with no data maps to {@code null}.
     *
     * @param list serialized properties
     * @return a new mutable map of decoded properties
     */
    public static Map<String, Object> mapFlatPropertiesToFunctionProperties(Iterable<FlatProperties> list) {
        Map<String, Object> out = new HashMap<>();
        for (FlatProperties p : list) {
            String name = p.name();
            if (p.shapeLength() > 0) {
                int[] shape = new int[p.shapeLength()];
                for (int k = 0; k < shape.length; k++) {
                    shape[k] = p.shape(k);
                }
                putShapedProperty(out, name, p, shape);
            } else {
                putScalarProperty(out, name, p);
            }
        }
        return out;
    }

    /**
     * Decodes an array-valued property, checking the vectors in order int, double, long, boolean,
     * String, INDArray. Rank 1 yields a flat array; ranks 2 and 3 are reshaped via
     * {@link ArrayUtil}. Ranks above 3 are skipped, so no entry is added. If every vector is empty,
     * {@code null} is stored.
     */
    private static void putShapedProperty(Map<String, Object> out, String name, FlatProperties p, int[] shape) {
        int rank = shape.length;
        if (p.iLength() > 0) {
            int[] arr = new int[p.iLength()];
            for (int k = 0; k < arr.length; k++) {
                arr[k] = p.i(k);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeInt(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeInt(arr, shape[0], shape[1], shape[2]));
            }
        } else if (p.dLength() > 0) {
            double[] arr = new double[p.dLength()];
            for (int k = 0; k < arr.length; k++) {
                arr[k] = p.d(k);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeDouble(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeDouble(arr, shape[0], shape[1], shape[2]));
            }
        } else if (p.lLength() > 0) {
            long[] arr = new long[p.lLength()];
            for (int k = 0; k < arr.length; k++) {
                arr[k] = p.l(k);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeLong(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeLong(arr, shape[0], shape[1], shape[2]));
            }
        } else if (p.bLength() > 0) {
            boolean[] arr = new boolean[p.bLength()];
            for (int k = 0; k < arr.length; k++) {
                arr[k] = p.b(k);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeBoolean(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeBoolean(arr, shape[0], shape[1], shape[2]));
            }
        } else if (p.sLength() > 0) {
            String[] arr = new String[p.sLength()];
            for (int k = 0; k < arr.length; k++) {
                arr[k] = p.s(k);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeObject(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeObject(arr, shape[0], shape[1], shape[2]));
            }
        } else if (p.aLength() > 0) {
            INDArray[] arr = new INDArray[p.aLength()];
            for (int k = 0; k < arr.length; k++) {
                // Read element k; previously every element was read from index 0.
                FlatArray fa = p.a(k);
                arr[k] = Nd4j.createFromFlatArray(fa);
            }
            if (rank <= 1) {
                out.put(name, arr);
            } else if (rank == 2) {
                out.put(name, ArrayUtil.reshapeObject(arr, shape[0], shape[1]));
            } else if (rank == 3) {
                out.put(name, ArrayUtil.reshapeObject(arr, shape[0], shape[1], shape[2]));
            }
        } else {
            out.put(name, null);
        }
    }

    /**
     * Decodes a shapeless property as the first element of the first non-empty vector, checked in
     * order boolean, int, long, double, String, INDArray. If every vector is empty, {@code null}
     * is stored.
     */
    private static void putScalarProperty(Map<String, Object> out, String name, FlatProperties p) {
        if (p.bLength() > 0) {
            out.put(name, p.b(0));
        } else if (p.iLength() > 0) {
            out.put(name, p.i(0));
        } else if (p.lLength() > 0) {
            out.put(name, p.l(0));
        } else if (p.dLength() > 0) {
            out.put(name, p.d(0));
        } else if (p.sLength() > 0) {
            out.put(name, p.s(0));
        } else if (p.aLength() > 0) {
            FlatArray fa = p.a(0);
            out.put(name, Nd4j.createFromFlatArray(fa));
        } else {
            out.put(name, null);
        }
    }
}

