/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner.perceptron;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.openimaj.ml.linear.kernel.VectorKernel;
import org.openimaj.ml.linear.learner.perceptron.KernelPerceptron;
import org.openimaj.ml.linear.learner.perceptron.PerceptronClass;
import org.openimaj.util.pair.IndependentPair;

/**
 * A kernel perceptron over {@code double[]} inputs whose model is a list of
 * support vectors, each with one accumulated weight.
 * <p>
 * The decision value for an input {@code x} is
 * {@code getBias() + sum_i weights[i] * kernel(correct(supports[i]), correct(x))},
 * and {@link #predict(double[])} maps the sign of that value to a
 * {@link PerceptronClass}. Updates with an input whose contents equal an
 * existing support (element-wise) add to that support's weight instead of
 * creating a duplicate support.
 */
public class MatrixKernelPerceptron
extends KernelPerceptron<double[], PerceptronClass> {
    /** Support vectors; {@link #update} stores a private copy of each new input here. */
    protected List<double[]> supports = new ArrayList<double[]>();
    /** {@code weights.get(i)} is the accumulated weight of {@code supports.get(i)}. */
    protected List<Double> weights = new ArrayList<Double>();
    /** Maps the contents of each support vector to its position in {@link #supports}/{@link #weights}. */
    Map<WrappedDouble, Integer> index = new HashMap<WrappedDouble, Integer>();

    /**
     * Creates an empty perceptron (no supports, so the initial bias is 0).
     *
     * @param k the kernel used to compare an input with each support vector
     */
    public MatrixKernelPerceptron(VectorKernel k) {
        super(k);
    }

    /**
     * Prepares a vector before it is handed to the kernel. Both the query and
     * every stored support pass through this method in {@link #mapping}.
     * This implementation returns an unmodified copy; subclasses may override
     * it to transform inputs.
     *
     * @param in the vector to prepare
     * @return a new array with the same contents as {@code in}
     */
    public double[] correct(double[] in) {
        return in.clone();
    }

    /**
     * Computes the raw (unthresholded) decision value for an input.
     *
     * @param in the input vector
     * @return {@code getBias()} plus the weighted kernel similarity of the
     *         corrected input to every corrected support vector
     */
    protected double mapping(double[] in) {
        double ret = this.getBias();
        final double[] query = this.correct(in);
        for (int i = 0; i < this.supports.size(); ++i) {
            final double alpha = this.weights.get(i);
            // Supports are re-corrected on every call so that an overridden
            // correct() is always applied consistently to query and support.
            final double[] x_i = this.correct(this.supports.get(i));
            ret += alpha * this.kernel.apply(IndependentPair.pair(x_i, query));
        }
        return ret;
    }

    /**
     * Predicts the class of an input from the sign of {@link #mapping}.
     *
     * @param x the input vector
     * @return the class for {@code Math.signum(mapping(x))}
     */
    @Override
    public PerceptronClass predict(double[] x) {
        return PerceptronClass.fromSign(Math.signum(this.mapping(x)));
    }

    /**
     * Adds {@code getUpdateRate() * yt.v()} to the weight of {@code xt}.
     * If no support with the same contents exists yet, a copy of {@code xt}
     * becomes a new support with that amount as its weight.
     *
     * @param xt       the input vector; it is copied, so the caller may reuse
     *                 or modify the array afterwards
     * @param yt       the target class whose {@code v()} scales the update
     * @param yt_prime not used by this implementation
     * @throws NullPointerException if {@code xt} or {@code yt} is null
     */
    @Override
    public void update(double[] xt, PerceptronClass yt, PerceptronClass yt_prime) {
        final double updateAmount = this.getUpdateRate() * (double) yt.v();
        final Integer existing = this.index.get(new WrappedDouble(xt));
        if (existing == null) {
            // Store a private copy: keeping the caller's array would let later
            // mutation corrupt both the support and its hash key in index.
            final double[] support = xt.clone();
            this.index.put(new WrappedDouble(support), this.supports.size());
            this.supports.add(support);
            this.weights.add(updateAmount);
        } else {
            final int pos = existing;
            this.weights.set(pos, this.weights.get(pos) + updateAmount);
        }
    }

    /**
     * @return the factor applied to every weight update; always {@code 1.0} here
     */
    double getUpdateRate() {
        return 1.0;
    }

    /**
     * @return the live internal list of support vectors (not a copy)
     */
    @Override
    public List<double[]> getSupports() {
        return this.supports;
    }

    /**
     * @return the live internal list of support weights (not a copy),
     *         parallel to {@link #getSupports()}
     */
    @Override
    public List<Double> getWeights() {
        return this.weights;
    }

    /**
     * Returns the bias as the sum of all support weights, i.e. what each
     * support would contribute through an extra constant kernel term of 1.
     *
     * @return the sum of {@link #weights}; 0 when there are no supports
     */
    @Override
    public double getBias() {
        double bias = 0.0;
        for (double w : this.weights) {
            bias += w;
        }
        return bias;
    }

    /**
     * Map key that compares {@code double[]} contents element-wise (via
     * {@link Arrays#equals(double[], double[])}) instead of by reference.
     * The wrapped array must not be modified while it is used as a key.
     */
    static class WrappedDouble {
        private final double[] d;

        /**
         * @param d the array to wrap (not copied)
         */
        public WrappedDouble(double[] d) {
            this.d = d;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof WrappedDouble)) {
                return false;
            }
            return Arrays.equals(this.d, ((WrappedDouble) obj).d);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(this.d);
        }
    }
}

