package com.clearnlp.run;

import com.clearnlp.classification.algorithm.old.AbstractAdaGrad;
import com.clearnlp.classification.algorithm.old.AdaGradHinge;
import com.clearnlp.classification.algorithm.old.AdaGradLR;
import com.clearnlp.classification.model.AbstractModel;
import com.clearnlp.classification.train.AbstractTrainSpace;
import com.clearnlp.classification.train.SparseTrainSpace;
import com.clearnlp.classification.train.StringTrainSpace;
import com.clearnlp.propbank.frameset.PBFLib;
import com.clearnlp.util.UTInput;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import org.kohsuke.args4j.Option;

/* loaded from: input_file:com/clearnlp/run/AdaGradTrain.class */
/**
 * Command-line runner that trains a classifier with one of the AdaGrad solvers and
 * writes the trained model to disk using Java object serialization.
 *
 * <p>Vector space types ({@code -v}-style option below): {@code 0} = sparse vector space,
 * {@code 1} = string vector space. Solver types ({@code -s}): {@code 3} = AdaGrad using
 * hinge loss, {@code 4} = AdaGrad using logistic regression.
 */
public class AdaGradTrain extends AbstractRun {

    /** Vector space type that selects {@link SparseTrainSpace}. */
    private static final byte VECTOR_SPARSE = 0;
    /** Vector space type that selects {@link StringTrainSpace}. */
    private static final byte VECTOR_STRING = 1;
    /** Solver type that selects {@link AdaGradHinge}. */
    private static final byte SOLVER_HINGE = 3;
    /** Solver type that selects {@link AdaGradLR}. */
    private static final byte SOLVER_LR = 4;

    @Option(name = "-i", usage = "the training file (input; required)", required = true, metaVar = "<filename>")
    private String s_trainFile;

    @Option(name = "-m", usage = "the model file (output; required)", required = true, metaVar = "<filename>")
    private String s_modelFile;

    @Option(name = "-nl", usage = "label frequency cutoff (default: 0)\nexclusive, string vector space only", required = false, metaVar = "<integer>")
    private int i_labelCutoff = 0;

    @Option(name = "-nf", usage = "feature frequency cutoff (default: 0)\nexclusive, string vector space only", required = false, metaVar = "<integer>")
    private int i_featureCutoff = 0;

    // NOTE(review): this option's name is the compile-time constant PBFLib.EXT_VERB. That is
    // most likely a decompiler substituting a matching constant for the original string
    // literal. Confirm the intended flag and replace it with a literal if appropriate.
    @Option(name = PBFLib.EXT_VERB, usage = "the type of vector space (default: 1)\n0: sparse vector space\n1: string vector space\n", required = false, metaVar = "<byte>")
    private byte i_vectorType = VECTOR_STRING;

    @Option(name = "-s", usage = "the type of solver (default: 3)\n3: AdaGrad using hinge loss\n4: AdaGrad using logistic regression", required = false, metaVar = "<byte>")
    private byte i_solver = SOLVER_HINGE;

    @Option(name = "-a", usage = "the cost (default: 0.01)", required = false, metaVar = "<double>")
    private double d_alpha = 0.01d;

    @Option(name = "-r", usage = "the ridge (default: 0.1)", required = false, metaVar = "<double>")
    private double d_rho = 0.1d;

    @Option(name = "-e", usage = "the terminal criterion (default: 0.05)", required = false, metaVar = "<double>")
    private double d_eps = 0.05d;

    @Option(name = "-average", usage = "if true, average weights", required = false, metaVar = "<boolean>")
    private boolean b_average = false;

    /** Creates a runner without parsing arguments or training. */
    public AdaGradTrain() {
    }

    /**
     * Parses the command-line arguments into the option fields and immediately trains
     * and saves a model.
     *
     * <p>Any exception raised by training is printed to standard error and not rethrown.
     * This constructor acts as the CLI boundary.
     *
     * @param args command-line arguments understood by the {@link Option} fields above
     */
    public AdaGradTrain(String[] args) {
        initArgs(args);
        // NOTE(review): train(...) is overridable and is invoked from a constructor.
        try {
            train(s_trainFile, s_modelFile, i_vectorType, i_labelCutoff, i_featureCutoff,
                    i_solver, d_alpha, d_rho, d_eps, b_average);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads training instances, trains a model, and serializes it to {@code modelFile}.
     *
     * @param trainFile     path of the training file to read
     * @param modelFile     path the serialized model is written to (created or overwritten)
     * @param vectorType    {@code 0} for a sparse vector space, {@code 1} for a string vector space
     * @param labelCutoff   label frequency cutoff (used by the string vector space only)
     * @param featureCutoff feature frequency cutoff (used by the string vector space only)
     * @param solver        solver type; see {@link #getAlgorithm}
     * @param alpha         first solver hyper-parameter (the {@code -a} option)
     * @param rho           second solver hyper-parameter (the {@code -r} option)
     * @param eps           third solver hyper-parameter (the {@code -e} option)
     * @param average       passed to the solver's {@code updateWeights}
     * @throws IllegalArgumentException if {@code vectorType} or {@code solver} is unsupported
     * @throws Exception                if reading, training, or writing fails
     */
    public void train(String trainFile, String modelFile, byte vectorType, int labelCutoff, int featureCutoff,
            byte solver, double alpha, double rho, double eps, boolean average) throws Exception {
        boolean hasWeight = AbstractTrainSpace.hasWeight(vectorType, trainFile);
        AbstractTrainSpace space;

        switch (vectorType) {
            case VECTOR_SPARSE:
                space = new SparseTrainSpace(hasWeight);
                break;
            case VECTOR_STRING:
                space = new StringTrainSpace(hasWeight, labelCutoff, featureCutoff);
                break;
            default:
                // Previously this fell through with a null space and failed with an NPE below.
                throw new IllegalArgumentException("Unsupported vector space type: " + vectorType);
        }

        // NOTE(review): it is not visible here whether readInstances closes the reader it receives.
        space.readInstances(UTInput.createBufferedFileReader(trainFile));
        space.build();
        AbstractModel model = getModel(space, solver, alpha, rho, eps, average);

        // Both resources are declared separately so the file handle is released even if the
        // ObjectOutputStream constructor (which writes a stream header) or writeObject throws.
        try (FileOutputStream fileOut = new FileOutputStream(modelFile);
             ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(fileOut))) {
            out.writeObject(model);
        }
    }

    /**
     * Creates the AdaGrad solver identified by {@code solver}.
     *
     * @param solver {@code 3} for {@link AdaGradHinge}, {@code 4} for {@link AdaGradLR}
     * @param alpha  first constructor argument passed to the solver
     * @param rho    second constructor argument passed to the solver
     * @param eps    third constructor argument passed to the solver
     * @return the new solver, or {@code null} if {@code solver} is not recognized
     */
    public static AbstractAdaGrad getAlgorithm(byte solver, double alpha, double rho, double eps) {
        switch (solver) {
            case SOLVER_HINGE:
                return new AdaGradHinge(alpha, rho, eps);
            case SOLVER_LR:
                return new AdaGradLR(alpha, rho, eps);
            default:
                return null;
        }
    }

    /**
     * Trains the weights of {@code space} with the selected solver and returns its model.
     *
     * @param space   the built training space
     * @param solver  solver type; see {@link #getAlgorithm}
     * @param alpha   first solver hyper-parameter
     * @param rho     second solver hyper-parameter
     * @param eps     third solver hyper-parameter
     * @param average passed to the solver's {@code updateWeights}
     * @return the model obtained from {@code space} after training
     * @throws IllegalArgumentException if {@code solver} is not recognized
     */
    public static AbstractModel getModel(AbstractTrainSpace space, byte solver, double alpha, double rho, double eps, boolean average) {
        AbstractAdaGrad algorithm = getAlgorithm(solver, alpha, rho, eps);
        if (algorithm == null) {
            throw new IllegalArgumentException("Unsupported solver type: " + solver);
        }
        algorithm.updateWeights(space, average);
        return space.getModel();
    }

    /**
     * Program entry point; the constructor performs argument parsing and training.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        new AdaGradTrain(args);
    }
}
