/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

public final class AirfoilRandomAccess
extends RandomAccessDataset {
    private static final String ARTIFACT_ID = "airfoil";
    private static final String[] FEATURE_ARRAY = new String[]{"freq", "aoa", "chordlen", "freestreamvel", "ssdt"};
    private Set<String> features;
    private Set<String> availableFeatures;
    private String label;
    private List<CSVRecord> csvRecords;
    private Dataset.Usage usage;
    private float[][] data;
    private float[] labelArray;
    private Map<String, Integer> stringToIndex;
    private Resource resource;
    private boolean prepared;

    private AirfoilRandomAccess(Builder builder) {
        super((RandomAccessDataset.BaseBuilder)builder);
        MRL mrl = MRL.dataset((Application)Application.Tabular.LINEAR_REGRESSION, (String)builder.groupId, (String)builder.artifactId);
        this.resource = new Resource(builder.repository, mrl, "1.0");
        this.usage = builder.usage;
        this.features = new HashSet<String>();
        this.availableFeatures = new HashSet<String>(Arrays.asList(FEATURE_ARRAY));
        this.label = "ssoundpres";
        this.stringToIndex = new HashMap<String, Integer>();
        for (int i = 0; i < FEATURE_ARRAY.length; ++i) {
            this.stringToIndex.put(FEATURE_ARRAY[i], i);
        }
        this.stringToIndex.put(this.label, FEATURE_ARRAY.length);
    }

    public void whitenAll() throws IOException, TranslateException {
        int index;
        int index2;
        this.prepare();
        float[] meanArray = new float[FEATURE_ARRAY.length + 1];
        float[] sdArray = new float[FEATURE_ARRAY.length + 1];
        for (CSVRecord record : this.csvRecords) {
            for (String feature : FEATURE_ARRAY) {
                int n = index2 = this.stringToIndex.get(feature).intValue();
                meanArray[n] = meanArray[n] + this.getRecordFloat(record, feature);
            }
            int n = index = this.stringToIndex.get(this.label).intValue();
            meanArray[n] = meanArray[n] + this.getRecordFloat(record, this.label);
        }
        int i = 0;
        while (i < meanArray.length) {
            int n = i++;
            meanArray[n] = meanArray[n] / (float)this.size();
        }
        for (CSVRecord record : this.csvRecords) {
            for (String feature : FEATURE_ARRAY) {
                int n = index2 = this.stringToIndex.get(feature).intValue();
                sdArray[n] = sdArray[n] + (float)Math.pow(this.getRecordFloat(record, feature) - meanArray[index2], 2.0);
            }
            int n = index = this.stringToIndex.get(this.label).intValue();
            sdArray[n] = sdArray[n] + (float)Math.pow(this.getRecordFloat(record, this.label) - meanArray[index], 2.0);
        }
        for (i = 0; i < sdArray.length; ++i) {
            sdArray[i] = (float)Math.sqrt(sdArray[i] / (float)this.csvRecords.size());
        }
        this.data = new float[(int)this.size()][this.getFeatureArraySize()];
        this.labelArray = new float[(int)this.size()];
        i = 0;
        while ((long)i < this.size()) {
            CSVRecord record;
            record = this.csvRecords.get(i);
            for (String feature : FEATURE_ARRAY) {
                index2 = this.stringToIndex.get(feature);
                this.data[i][index2] = (this.getRecordFloat(record, feature) - meanArray[index2]) / sdArray[index2];
            }
            this.labelArray[i] = (this.getRecordFloat(record, this.label) - meanArray[FEATURE_ARRAY.length]) / sdArray[FEATURE_ARRAY.length];
            ++i;
        }
    }

    public List<String> getFeatureOrder() {
        return new ArrayList<String>(this.features);
    }

    public float getRecordFloat(CSVRecord record, String feature) {
        return Float.parseFloat(record.get(feature));
    }

    public void selectFirstN(int n) throws IOException, TranslateException {
        this.prepare();
        this.csvRecords.subList(n, this.csvRecords.size()).clear();
    }

    public static Builder builder() {
        return new Builder();
    }

    public float[] getLabel(int index) {
        return new float[]{this.labelArray[index]};
    }

    protected Record get(NDManager manager, long index) {
        int idx = Math.toIntExact(index);
        NDList d = new NDList(new NDArray[]{this.getFeatureNDArray(manager, idx)});
        NDList l = new NDList(new NDArray[]{manager.create(this.getLabel(idx))});
        return new Record(d, l);
    }

    public CSVRecord getCSVRecord(int index) {
        return this.csvRecords.get(index);
    }

    public float[] getValueFloat(CSVRecord record, String feature) {
        return new float[]{Float.parseFloat(record.get(feature))};
    }

    public int getFeatureArraySize() {
        return this.features.size();
    }

    public NDArray getFeatureNDArray(NDManager manager, int index) {
        float[] newFeatureArray = new float[this.getFeatureArraySize()];
        int i = 0;
        for (String feature : this.features) {
            int featureIndex = this.stringToIndex.get(feature);
            newFeatureArray[i] = this.data[index][featureIndex];
            ++i;
        }
        return manager.create(newFeatureArray);
    }

    public void removeAllFeatures() {
        this.availableFeatures.addAll(this.features);
        this.features.clear();
    }

    public void addAllFeatures() {
        this.features.addAll(this.availableFeatures);
        this.availableFeatures.clear();
    }

    public void addFeature(String feature) {
        if (this.availableFeatures.contains(feature = feature.toLowerCase())) {
            this.availableFeatures.remove(feature);
            this.features.add(feature);
        }
    }

    public void removeFeature(String feature) {
        if (this.features.contains(feature = feature.toLowerCase())) {
            this.features.remove(feature);
            this.availableFeatures.add(feature);
        }
    }

    public void prepare(Progress progress) throws IOException {
        Path csvFile;
        if (this.prepared) {
            return;
        }
        Artifact artifact = this.resource.getDefaultArtifact();
        this.resource.prepare(artifact);
        Path root = this.resource.getRepository().getResourceDirectory(artifact);
        switch (this.usage) {
            case TRAIN: {
                csvFile = root.resolve("airfoil_self_noise.dat");
                break;
            }
            case TEST: {
                throw new UnsupportedOperationException("Test data not available.");
            }
            default: {
                throw new UnsupportedOperationException("Validation data not available.");
            }
        }
        String[] stringArray = null;
        try (BufferedReader reader = Files.newBufferedReader(csvFile);
             CSVParser csvParser = new CSVParser((Reader)reader, CSVFormat.TDF.withHeader(new String[]{"freq", "aoa", "chordlen", "freestreamvel", "ssdt", "ssoundpres"}).withIgnoreHeaderCase().withTrim());){
            this.csvRecords = csvParser.getRecords();
        }
        catch (Throwable object) {
            stringArray = object;
            throw object;
        }
        this.data = new float[(int)this.size()][FEATURE_ARRAY.length];
        this.labelArray = new float[(int)this.size()];
        for (int i = 0; i < this.csvRecords.size(); ++i) {
            for (String feature : FEATURE_ARRAY) {
                int featureIndex = this.stringToIndex.get(feature);
                this.data[i][featureIndex] = this.getRecordFloat(this.getCSVRecord(i), feature);
            }
            this.labelArray[i] = this.getRecordFloat(this.getCSVRecord(i), this.label);
        }
        this.prepared = true;
    }

    protected long availableSize() {
        return this.csvRecords.size();
    }

    public static final class Builder
    extends RandomAccessDataset.BaseBuilder<Builder> {
        Repository repository = BasicDatasets.REPOSITORY;
        String groupId = "ai.djl.basicdataset";
        String artifactId = "airfoil";
        Dataset.Usage usage = Dataset.Usage.TRAIN;

        Builder() {
        }

        public Builder self() {
            return this;
        }

        public Builder optUsage(Dataset.Usage usage) {
            this.usage = usage;
            return this.self();
        }

        public Builder optRepository(Repository repository) {
            this.repository = repository;
            return this.self();
        }

        public Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        public Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        public AirfoilRandomAccess build() {
            return new AirfoilRandomAccess(this);
        }
    }
}

