/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.annotation.svm;

import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.utils.AnnotatedListHelper;
import org.openimaj.util.array.ArrayUtils;

/**
 * A two-class {@link BatchAnnotator} backed by a libsvm C-SVC classifier with an
 * RBF kernel.
 * <p>
 * Every training object must carry exactly one annotation, and the training data
 * must contain exactly two distinct annotations. The first distinct annotation
 * encountered is mapped to {@link #NEGATIVE_CLASS} and the second to
 * {@link #POSITIVE_CLASS}. Features are produced by the supplied
 * {@link FeatureExtractor} and converted to dense libsvm node arrays with
 * zero-based indices.
 *
 * @param <OBJECT> type of object being annotated
 * @param <ANNOTATION> type of annotation
 */
public class SVMAnnotator<OBJECT, ANNOTATION>
extends BatchAnnotator<OBJECT, ANNOTATION> {
    /** libsvm label used for the positive class. */
    public static final int POSITIVE_CLASS = 1;
    /** libsvm label used for the negative class. */
    public static final int NEGATIVE_CLASS = -1;
    /**
     * Maps libsvm labels ({@link #POSITIVE_CLASS} / {@link #NEGATIVE_CLASS}) to the
     * annotations they represent. It is populated by {@link #train(List)}.
     */
    public HashMap<Integer, ANNOTATION> classMap = new HashMap<Integer, ANNOTATION>();
    private svm_model model = null;
    private FeatureExtractor<? extends FeatureVector, OBJECT> extractor = null;
    private File saveModel = null;

    /**
     * Creates an untrained annotator.
     *
     * @param extractor extractor used to turn objects into feature vectors
     */
    public SVMAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor) {
        this.extractor = extractor;
    }

    /**
     * Trains the SVM on the given data. If a save location was set with
     * {@link #setSaveModel(File)}, this also writes the model there.
     *
     * @param data training objects, each with exactly one annotation
     * @throws IllegalArgumentException if an object does not have exactly one
     *         annotation, or the data does not contain exactly two classes
     */
    @Override
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
        this.checkInputDataOK(data);
        svm_parameter param = getDefaultSVMParameters();
        svm_problem prob = this.getSVMProblem(data, param);
        this.model = svm.svm_train(prob, param);
        if (this.saveModel != null) {
            try {
                svm.svm_save_model(this.saveModel.getAbsolutePath(), this.model);
            } catch (IOException e) {
                // Saving is best-effort: the freshly trained model is still usable in memory.
                e.printStackTrace();
            }
        }
    }

    /**
     * Validates the data and rebuilds {@link #classMap} from it. The mapping is only
     * replaced once validation succeeds, so a rejected data set leaves the previous
     * mapping intact.
     *
     * @param data training objects
     * @return always {@code true}; invalid data is reported by exception
     * @throws IllegalArgumentException if an object does not have exactly one
     *         annotation, or the data does not contain exactly two classes
     */
    private boolean checkInputDataOK(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
        HashMap<Integer, ANNOTATION> found = new HashMap<Integer, ANNOTATION>();
        int i = 0;
        for (Annotated<OBJECT, ANNOTATION> x : data) {
            Collection<ANNOTATION> anns = x.getAnnotations();
            if (anns.size() != 1) {
                throw new IllegalArgumentException("Data contained an object with " + anns.size()
                        + " annotations; exactly one annotation per object is required");
            }
            ANNOTATION onlyAnnotation = anns.iterator().next();
            if (found.containsValue(onlyAnnotation)) {
                continue;
            }
            // Keys run -1, 1, 3, 5, ...: the first two distinct annotations become
            // NEGATIVE_CLASS and POSITIVE_CLASS; any further ones only feed the error below.
            found.put(i * 2 - 1, onlyAnnotation);
            ++i;
        }
        if (found.size() != 2) {
            throw new IllegalArgumentException("Data did not contain exactly 2 classes. It had " + found.size() + ". They were " + found);
        }
        this.classMap.clear();
        this.classMap.putAll(found);
        return true;
    }

    /**
     * Returns the set of annotations this annotator can produce.
     *
     * @return a new set holding the annotations currently in {@link #classMap}
     */
    @Override
    public Set<ANNOTATION> getAnnotations() {
        return new HashSet<ANNOTATION>(this.classMap.values());
    }

    /**
     * Classifies an object. The result always carries a confidence of {@code 1.0f},
     * because the model is trained without probability estimates.
     * <p>
     * NOTE(review): {@link #loadModel(File)} does not restore {@link #classMap}. After
     * loading a model without training, the returned annotation is whatever the map
     * holds for the predicted label (possibly {@code null}).
     *
     * @param object the object to classify
     * @return a singleton list with the predicted annotation
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        if (this.model == null) {
            throw new IllegalStateException("No SVM model available; call train() or loadModel() first");
        }
        svm_node[] nodes = featureToNode(this.extractor.extractFeature(object));
        double prediction = svm.svm_predict(this.model, nodes);
        ANNOTATION label = prediction > 0.0 ? this.classMap.get(POSITIVE_CLASS) : this.classMap.get(NEGATIVE_CLASS);
        return Collections.singletonList(new ScoredAnnotation<ANNOTATION>(label, 1.0f));
    }

    /**
     * Sets the file that subsequent calls to {@link #train(List)} save the model to.
     *
     * @param saveModel destination file, or {@code null} to disable saving
     */
    public void setSaveModel(File saveModel) {
        this.saveModel = saveModel;
    }

    /**
     * Loads a libsvm model from disk, replacing any current model.
     *
     * @param loadModel the model file
     * @throws IOException if the model cannot be read
     */
    public void loadModel(File loadModel) throws IOException {
        this.model = svm.svm_load_model(loadModel.getAbsolutePath());
    }

    /**
     * Runs n-fold cross-validation on the given data using the default parameters.
     * If no valid two-class mapping exists yet (e.g. {@link #train(List)} has not
     * been called), one is derived from {@code data} first.
     *
     * @param data labelled objects
     * @param numFold number of folds (at least 2)
     * @return cross-validation accuracy as a percentage
     */
    public double crossValidation(List<? extends Annotated<OBJECT, ANNOTATION>> data, int numFold) {
        // Without a mapping the class lookups return null and the problem would be empty.
        if (this.classMap.size() != 2) {
            this.checkInputDataOK(data);
        }
        svm_parameter param = getDefaultSVMParameters();
        svm_problem prob = this.getSVMProblem(data, param);
        return crossValidation(prob, param, numFold);
    }

    /**
     * Runs libsvm n-fold cross-validation on a prepared problem and prints the
     * accuracy to standard output.
     *
     * @param prob the libsvm problem
     * @param param the libsvm parameters
     * @param numFold number of folds (at least 2)
     * @return cross-validation accuracy as a percentage
     * @throws IllegalArgumentException if {@code numFold < 2}
     */
    public static double crossValidation(svm_problem prob, svm_parameter param, int numFold) {
        if (numFold < 2) {
            throw new IllegalArgumentException("numFold must be at least 2, but was " + numFold);
        }
        double[] target = new double[prob.l];
        svm.svm_cross_validation(prob, param, numFold, target);
        int totalCorrect = 0;
        for (int i = 0; i < prob.l; ++i) {
            if (target[i] == prob.y[i]) {
                ++totalCorrect;
            }
        }
        double accuracy = 100.0 * totalCorrect / prob.l;
        System.out.print("Cross Validation Accuracy = " + accuracy + "%\n");
        return accuracy;
    }

    /**
     * Builds the default parameters: C-SVC with an RBF kernel and C = 1. Gamma is
     * left at 0 here and is set from the data in {@code getSVMProblem}.
     */
    private static svm_parameter getDefaultSVMParameters() {
        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.degree = 3;
        param.gamma = 0.0;
        param.coef0 = 0.0;
        param.nu = 0.5;
        param.cache_size = 100.0;
        param.C = 1.0;
        param.eps = 0.001;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }

    /**
     * Converts the data into a libsvm problem. Positive samples come first, labelled
     * {@link #POSITIVE_CLASS}, followed by negative samples labelled
     * {@link #NEGATIVE_CLASS}. Also sets {@code param.gamma} to 1 / number of
     * features, mirroring libsvm's default.
     */
    private svm_problem getSVMProblem(List<? extends Annotated<OBJECT, ANNOTATION>> data, svm_parameter param) {
        svm_node[][] positiveNodes = this.computeFeature(data, this.classMap.get(POSITIVE_CLASS));
        svm_node[][] negativeNodes = this.computeFeature(data, this.classMap.get(NEGATIVE_CLASS));
        int nPositive = positiveNodes.length;
        int nSamples = nPositive + negativeNodes.length;

        svm_node[][] sampleArray = new svm_node[nSamples][];
        System.arraycopy(positiveNodes, 0, sampleArray, 0, nPositive);
        System.arraycopy(negativeNodes, 0, sampleArray, nPositive, negativeNodes.length);

        double[] flagArray = new double[nSamples];
        for (int i = 0; i < nSamples; ++i) {
            flagArray[i] = i < nPositive ? POSITIVE_CLASS : NEGATIVE_CLASS;
        }

        svm_problem prob = new svm_problem();
        prob.l = nSamples;
        prob.x = sampleArray;
        prob.y = flagArray;

        // Indices are zero-based, so the feature count is max index + 1. Using the raw
        // max index was off by one and divided by zero for 1-D features (gamma = Infinity).
        int numFeatures = getNumFeatures(sampleArray);
        if (numFeatures > 0) {
            param.gamma = 1.0 / numFeatures;
        }
        return prob;
    }

    /** Extracts and converts the features of every object carrying {@code annotation}. */
    private svm_node[][] computeFeature(List<? extends Annotated<OBJECT, ANNOTATION>> data, ANNOTATION annotation) {
        AnnotatedListHelper<OBJECT, ANNOTATION> alh = new AnnotatedListHelper<OBJECT, ANNOTATION>(data);
        List<? extends FeatureVector> f = alh.extractFeatures(annotation, this.extractor);
        svm_node[][] n = new svm_node[f.size()][];
        int i = 0;
        for (FeatureVector featureVector : f) {
            n[i++] = featureToNode(featureVector);
        }
        return n;
    }

    /**
     * Returns the number of feature dimensions in the samples, i.e. the largest
     * zero-based node index plus one. Returns 0 if there are no nodes at all.
     */
    private static int getNumFeatures(svm_node[][] sampleArray) {
        int maxIndex = -1;
        for (svm_node[] sample : sampleArray) {
            for (svm_node node : sample) {
                maxIndex = Math.max(maxIndex, node.index);
            }
        }
        return maxIndex + 1;
    }

    /**
     * Converts a feature vector to a dense libsvm node array with zero-based indices.
     * Indices are kept zero-based so that models saved by earlier versions stay
     * compatible.
     */
    private static svm_node[] featureToNode(FeatureVector featureVector) {
        double[] fv = featureVector.asDoubleVector();
        svm_node[] nodes = new svm_node[fv.length];
        for (int i = 0; i < fv.length; ++i) {
            nodes[i] = new svm_node();
            nodes[i].index = i;
            nodes[i].value = fv[i];
        }
        return nodes;
    }
}

