/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.util.ndarray;

import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.BooleanWritable;
import org.datavec.api.writable.ByteWritable;
import org.datavec.api.writable.BytesWritable;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Static helpers for converting between DataVec records ({@code List<Writable>}) and ND4J arrays / data sets.
 */
public class RecordConverter {
    private RecordConverter() {
    }

    /**
     * Converts a record to an array; identical to {@link #toArray(Collection)}.
     *
     * @param record the writables to convert
     * @param size   ignored; kept only for source compatibility
     * @return the result of {@link #toArray(Collection)}
     * @deprecated the {@code size} argument is unused; call {@link #toArray(Collection)} instead
     */
    @Deprecated
    public static INDArray toArray(Collection<Writable> record, int size) {
        return RecordConverter.toArray(record);
    }

    /**
     * Splits a matrix into records, one per row.
     *
     * @param matrix the matrix to split; rows are taken with {@code getRow}
     * @return one record per row, each holding a single {@link NDArrayWritable} wrapping that row
     */
    public static List<List<Writable>> toRecords(INDArray matrix) {
        int rows = matrix.rows();
        List<List<Writable>> ret = new ArrayList<List<Writable>>(rows);
        for (int i = 0; i < rows; i++) {
            ret.add(RecordConverter.toRecord(matrix.getRow((long) i)));
        }
        return ret;
    }

    /**
     * Converts a batch of sequences using {@link TimeSeriesWritableUtils#convertWritablesSequence}.
     *
     * @param records the sequences (each a list of time steps, each a list of writables)
     * @return the first element of the pair produced by {@code convertWritablesSequence}
     *         (presumably the data tensor; the second element is discarded — TODO confirm whether it is a mask)
     */
    public static INDArray toTensor(List<List<List<Writable>>> records) {
        return (INDArray) TimeSeriesWritableUtils.convertWritablesSequence(records).getFirst();
    }

    /**
     * Converts each record with {@link #toArray(Collection)} and stacks the results vertically.
     *
     * @param records the records to convert, one output row (or block of rows) per record
     * @return the vertically stacked result of {@code Nd4j.vstack}
     */
    public static INDArray toMatrix(List<List<Writable>> records) {
        List<INDArray> toStack = new ArrayList<INDArray>(records.size());
        for (List<Writable> record : records) {
            toStack.add(RecordConverter.toArray(record));
        }
        return Nd4j.vstack(toStack);
    }

    /**
     * Converts a record to a single {@code 1 x N} row vector.
     *
     * <p>If the record consists of exactly one {@link NDArrayWritable}, that array is returned as-is
     * (not copied, any shape). Otherwise every {@link NDArrayWritable} contributes its elements in
     * order, and every other writable contributes one element via {@link Writable#toDouble()}.
     *
     * @param record the writables to convert; any {@link Collection} implementation is accepted
     * @return the wrapped array (single-NDArrayWritable case) or a new {@code 1 x N} row vector
     * @throws UnsupportedOperationException if the record has several writables and one of the
     *                                       NDArrayWritables is not a row vector
     * @throws ArithmeticException           if the total length does not fit in an {@code int}
     */
    public static INDArray toArray(Collection<? extends Writable> record) {
        // Inspect the collection directly: casting an arbitrary List to ArrayList would throw.
        if (record.size() == 1) {
            Writable only = record.iterator().next();
            if (only instanceof NDArrayWritable) {
                return ((NDArrayWritable) only).get();
            }
        }

        // First pass: validate NDArrayWritables and compute the total row length.
        long length = 0;
        for (Writable writable : record) {
            if (writable instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable) writable).get();
                if (!a.isRowVector()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is not a row vector. Can only concat row vectors with other writables. Shape: " + Arrays.toString(a.shape()));
                }
                length += a.length();
            } else {
                length++;
            }
        }

        INDArray arr = Nd4j.create(1, Math.toIntExact(length));

        // Second pass: copy values in, tracking the next free column.
        long offset = 0;
        for (Writable writable : record) {
            if (writable instanceof NDArrayWritable) {
                INDArray toPut = ((NDArrayWritable) writable).get();
                long end = offset + toPut.length();
                arr.put(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(offset, end)}, toPut);
                offset = end;
            } else {
                arr.putScalar(0L, offset, writable.toDouble());
                offset++;
            }
        }
        return arr;
    }

    /**
     * Converts a list of writables into a minibatch array.
     *
     * <p>Either every element is an {@link NDArrayWritable} with leading dimension 1, and they are
     * concatenated along dimension 0, or no element is, and their {@code toDouble()} values form an
     * {@code [n, 1]} column. A single {@link NDArrayWritable} is returned as-is without the
     * leading-dimension check.
     *
     * @param l the writables to convert; must be non-null and non-empty
     * @return the concatenated array, or an {@code [n, 1]} column of doubles
     * @throws NullPointerException          if {@code l} is null
     * @throws IllegalArgumentException      if {@code l} is empty
     * @throws UnsupportedOperationException if an NDArrayWritable's leading dimension is not 1
     * @throws IllegalStateException         if NDArrayWritables and single-value writables are mixed
     */
    public static INDArray toMinibatchArray(@NonNull List<? extends Writable> l) {
        if (l == null) {
            throw new NullPointerException("l is marked @NonNull but is null");
        }
        Preconditions.checkArgument(!l.isEmpty(), "Cannot convert empty list");
        if (l.size() == 1 && l.get(0) instanceof NDArrayWritable) {
            return ((NDArrayWritable) l.get(0)).get();
        }

        // Both stay null until the first element of their kind is seen; if both end up non-null the input is mixed.
        List<INDArray> toConcat = null;
        DoubleArrayList list = null;
        for (Writable writable : l) {
            if (writable instanceof NDArrayWritable) {
                INDArray a = ((NDArrayWritable) writable).get();
                if (a.size(0) != 1L) {
                    throw new UnsupportedOperationException("NDArrayWritable must have leading dimension 1 for this method. Received array with shape: " + Arrays.toString(a.shape()));
                }
                if (toConcat == null) {
                    toConcat = new ArrayList<INDArray>();
                }
                toConcat.add(a);
            } else {
                if (list == null) {
                    list = new DoubleArrayList();
                }
                list.add(writable.toDouble());
            }
        }

        if (toConcat != null && list != null) {
            throw new IllegalStateException("Error converting writables: found both NDArrayWritable and single value (DoubleWritable etc) in the one list. All writables must be NDArrayWritables or single value writables only for this method");
        }
        if (toConcat != null) {
            return Nd4j.concat(0, toConcat.toArray(new INDArray[toConcat.size()]));
        }
        // l is non-empty and not all NDArrayWritables, so list is non-null here.
        return Nd4j.create(list.toArray(new double[list.size()]), new int[]{list.size(), 1});
    }

    /**
     * Wraps an array in a single-element record.
     *
     * @param array the array to wrap
     * @return a mutable {@link ArrayList} containing one {@link NDArrayWritable}; callers in this
     *         class append label writables to it, so it must stay mutable
     */
    public static List<Writable> toRecord(INDArray array) {
        List<Writable> writables = new ArrayList<Writable>();
        writables.add(new NDArrayWritable(array));
        return writables;
    }

    /**
     * Converts plain Java values into a record, guided by a schema.
     *
     * <p>Each value is first checked with {@link ColumnMetaData#isValid}, then wrapped in the
     * writable type given by the column's {@code getColumnType().getWritableType()}. Text columns
     * accept {@code String}, {@link Text} or {@code byte[]} values.
     *
     * @param schema the schema describing each column
     * @param source one value per schema column, in column order
     * @return the converted record, in column order
     * @throws IllegalArgumentException if the sizes differ, a value fails validation, or a value
     *                                  cannot be converted to its column's writable type
     */
    public static List<Writable> toRecord(Schema schema, List<Object> source) {
        List<ColumnMetaData> columnMetaData = schema.getColumnMetaData();
        if (columnMetaData.size() != source.size()) {
            throw new IllegalArgumentException("Schema and source list don't have the same length!");
        }
        List<Writable> record = new ArrayList<Writable>(source.size());
        for (int i = 0; i < columnMetaData.size(); i++) {
            ColumnMetaData metaData = columnMetaData.get(i);
            Object data = source.get(i);
            if (!metaData.isValid(data)) {
                throw new IllegalArgumentException(columnError(i, data, metaData, "is not valid"));
            }
            try {
                record.add(toWritable(i, metaData, data));
            } catch (ClassCastException e) {
                // The value's Java type does not match the column's writable type.
                throw new IllegalArgumentException(columnError(i, data, metaData, "is not usable"), e);
            }
        }
        return record;
    }

    /** Wraps one value in the writable type of its column; casts may throw ClassCastException. */
    private static Writable toWritable(int index, ColumnMetaData metaData, Object data) {
        switch (metaData.getColumnType().getWritableType()) {
            case Float:
                return new FloatWritable(((Float) data).floatValue());
            case Double:
                return new DoubleWritable((Double) data);
            case Int:
                return new IntWritable((Integer) data);
            case Byte:
                return new ByteWritable((Byte) data);
            case Boolean:
                return new BooleanWritable((Boolean) data);
            case Long:
                return new LongWritable((Long) data);
            case Null:
                return new NullWritable();
            case Bytes:
                return new BytesWritable((byte[]) data);
            case NDArray:
                return new NDArrayWritable((INDArray) data);
            case Text:
                if (data instanceof String) {
                    return new Text((String) data);
                }
                if (data instanceof Text) {
                    return new Text((Text) data);
                }
                if (data instanceof byte[]) {
                    return new Text((byte[]) data);
                }
                throw new IllegalArgumentException(columnError(index, data, metaData, "is not usable"));
            default:
                throw new IllegalArgumentException(columnError(index, data, metaData, "is not usable"));
        }
    }

    /** Builds the shared "Element i: value <problem> for Column ..." error message. */
    private static String columnError(int index, Object data, ColumnMetaData metaData, String problem) {
        return "Element " + index + ": " + data + " " + problem + " for Column \"" + metaData.getName() + "\" (" + metaData.getColumnType() + ")";
    }

    /**
     * Converts a data set into records, one per example.
     *
     * <p>When the labels look like a classification target (see {@link #isClassificationDataSet}),
     * each record is the feature row followed by an {@link IntWritable} class index. Otherwise each
     * record is the feature row followed by one {@link DoubleWritable} per label value.
     *
     * @param dataSet the data set to convert
     * @return one record per example
     */
    public static List<List<Writable>> toRecords(DataSet dataSet) {
        return RecordConverter.isClassificationDataSet(dataSet)
                ? RecordConverter.getClassificationWritableMatrix(dataSet)
                : RecordConverter.getRegressionWritableMatrix(dataSet);
    }

    /**
     * Heuristically decides whether labels are a classification target: the sum over dimensions
     * 0 and 1 equals the number of examples, and there is more than one label column. This
     * matches one-hot labels, but regression labels that happen to satisfy the same conditions
     * are also treated as classification.
     */
    private static boolean isClassificationDataSet(DataSet dataSet) {
        INDArray labels = dataSet.getLabels();
        return labels.sum(new int[]{0, 1}).getInt(new int[]{0}) == dataSet.numExamples() && labels.shape()[1] > 1L;
    }

    /** Builds records of [feature row as NDArrayWritable, argmax of the label row as IntWritable]. */
    private static List<List<Writable>> getClassificationWritableMatrix(DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray labels = dataSet.getLabels();
        int numExamples = dataSet.numExamples();
        List<List<Writable>> writableMatrix = new ArrayList<List<Writable>>(numExamples);
        for (int i = 0; i < numExamples; i++) {
            List<Writable> writables = RecordConverter.toRecord(features.getRow((long) i));
            int labelIndex = Nd4j.argMax(labels.getRow((long) i), new int[]{1}).getInt(new int[]{0});
            writables.add(new IntWritable(labelIndex));
            writableMatrix.add(writables);
        }
        return writableMatrix;
    }

    /** Builds records of [feature row as NDArrayWritable, each label value as a DoubleWritable]. */
    private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray labels = dataSet.getLabels();
        int numExamples = dataSet.numExamples();
        List<List<Writable>> writableMatrix = new ArrayList<List<Writable>>(numExamples);
        for (int i = 0; i < numExamples; i++) {
            List<Writable> writables = RecordConverter.toRecord(features.getRow((long) i));
            INDArray labelRow = labels.getRow((long) i);
            // length() equals shape()[1] for a [1, n] row, and also works if getRow returns a rank-1 vector.
            long labelCount = labelRow.length();
            for (long j = 0; j < labelCount; j++) {
                writables.add(new DoubleWritable(labelRow.getDouble(j)));
            }
            writableMatrix.add(writables);
        }
        return writableMatrix;
    }
}

