/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.types.Shape;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;

public class NDList
extends ArrayList<NDArray>
implements NDResource {
    private static final long serialVersionUID = 1L;

    public NDList() {
    }

    public NDList(int initialCapacity) {
        super(initialCapacity);
    }

    public NDList(NDArray ... arrays) {
        super(Arrays.asList(arrays));
    }

    public NDList(Collection<NDArray> other) {
        super(other);
    }

    public static NDList decode(NDManager manager, byte[] byteArray) {
        return NDList.decode(manager, new ByteArrayInputStream(byteArray));
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public static NDList decode(NDManager manager, InputStream is) {
        try (DataInputStream dis = new DataInputStream(is);){
            int size = dis.readInt();
            if (size < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + size);
            }
            NDList list = new NDList();
            for (int i = 0; i < size; ++i) {
                list.add(i, manager.decode(dis));
            }
            NDList nDList = list;
            return nDList;
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    public NDArray remove(String name) {
        int index = 0;
        for (NDArray array : this) {
            if (name.equals(array.getName())) {
                this.remove(index);
                return array;
            }
            ++index;
        }
        return null;
    }

    public boolean contains(String name) {
        for (NDArray array : this) {
            if (!name.equals(array.getName())) continue;
            return true;
        }
        return false;
    }

    public NDArray head() {
        return (NDArray)this.get(0);
    }

    public NDArray singletonOrThrow() {
        if (this.size() != 1) {
            throw new IndexOutOfBoundsException("Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + this.size());
        }
        return (NDArray)this.get(0);
    }

    public NDList addAll(NDList other) {
        for (NDArray array : other) {
            this.add(array);
        }
        return this;
    }

    public NDList subNDList(int fromIndex) {
        return new NDList((Collection<NDArray>)this.subList(fromIndex, this.size()));
    }

    public NDList toDevice(Device device, boolean copy) {
        if (!copy && this.stream().allMatch(array -> array.getDevice() == device)) {
            return this;
        }
        NDList newNDList = new NDList(this.size());
        this.forEach((? super E a) -> newNDList.add(a.toDevice(device, copy)));
        return newNDList;
    }

    @Override
    public NDManager getManager() {
        return this.head().getManager();
    }

    @Override
    public void attach(NDManager manager) {
        this.stream().forEach((? super T array) -> array.attach(manager));
    }

    @Override
    public void tempAttach(NDManager manager) {
        this.stream().forEach((? super T array) -> array.tempAttach(manager));
    }

    @Override
    public void detach() {
        this.stream().forEach(NDResource::detach);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public byte[] encode() {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();){
            DataOutputStream dos = new DataOutputStream(baos);
            dos.writeInt(this.size());
            for (NDArray nd : this) {
                dos.write(nd.encode());
            }
            dos.flush();
            Object object = baos.toByteArray();
            return object;
        }
        catch (IOException e) {
            throw new AssertionError("NDList is not writable", e);
        }
    }

    public Shape[] getShapes() {
        return (Shape[])this.stream().map(NDArray::getShape).toArray(Shape[]::new);
    }

    @Override
    public void close() {
        this.forEach(NDArray::close);
        this.clear();
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder(200);
        builder.append("NDList size: ").append(this.size()).append('\n');
        int index = 0;
        for (NDArray array : this) {
            String name = array.getName();
            builder.append(index++).append(' ');
            if (name != null) {
                builder.append(name);
            }
            builder.append(": ").append(array.getShape()).append(' ').append((Object)array.getDataType()).append('\n');
        }
        return builder.toString();
    }
}

