/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.zero.tabular;

import ai.djl.Model;
import ai.djl.basicdataset.tabular.TabularDataset;
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.TabNetRegressionLoss;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;
import java.io.IOException;

public final class TabularRegression {
    private TabularRegression() {
    }

    public static ZooModel<NDList, NDList> train(TabularDataset dataset, Performance performance) throws IOException, TranslateException {
        RandomAccessDataset[] splitDataset = dataset.randomSplit(new int[]{8, 2});
        RandomAccessDataset trainDataset = splitDataset[0];
        RandomAccessDataset validateDataset = splitDataset[1];
        int featureSize = dataset.getFeatureSize();
        int labelSize = dataset.getLabelSize();
        Block block = performance.equals((Object)Performance.FAST) ? TabNet.builder().setInputDim(featureSize).setOutDim(labelSize).optNumIndependent(1).optNumShared(1).build() : (performance.equals((Object)Performance.BALANCED) ? TabNet.builder().setInputDim(featureSize).setOutDim(labelSize).build() : TabNet.builder().setInputDim(featureSize).setOutDim(labelSize).optNumIndependent(4).optNumShared(4).build());
        Model model = Model.newInstance((String)"tabular");
        model.setBlock(block);
        DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig((Loss)new TabNetRegressionLoss()).addTrainingListeners(TrainingListener.Defaults.basic());
        try (Trainer trainer = model.newTrainer((TrainingConfig)trainingConfig);){
            trainer.initialize(new Shape[]{new Shape(new long[]{1L, featureSize})});
            EasyTrain.fit((Trainer)trainer, (int)20, (Dataset)trainDataset, (Dataset)validateDataset);
        }
        return new ZooModel(model, (Translator)new NoopTranslator());
    }
}

