/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.annotation.basic;

import cern.jet.random.EmpiricalWalker;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.RandomEngine;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TObjectIntProcedure;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.BatchAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;
import org.openimaj.ml.annotation.basic.util.NumAnnotationsChooser;

/**
 * A {@link BatchAnnotator} that ignores the object being annotated and instead
 * draws annotations at random. Each annotation is drawn independently, with
 * probability proportional to the number of times it occurred in the training
 * data (its prior). The number of annotations emitted per call is decided by a
 * {@link NumAnnotationsChooser}.
 *
 * @param <OBJECT>
 *            type of object being annotated (never inspected)
 * @param <ANNOTATION>
 *            type of annotation
 */
public class IndependentPriorRandomAnnotator<OBJECT, ANNOTATION> extends BatchAnnotator<OBJECT, ANNOTATION> {
    /** Distinct annotations seen in training; index i corresponds to sampler outcome i. */
    protected List<ANNOTATION> annotations;
    /** Decides how many annotations to produce per call to {@link #annotate(Object)}. */
    protected NumAnnotationsChooser numAnnotations;
    /** Sampler over indices of {@link #annotations}, weighted by training frequency. */
    protected EmpiricalWalker annotationProbability;
    /** Normalised prior of each entry of {@link #annotations} (same indexing); used as the score. */
    private double[] annotationPriors;

    /**
     * Construct the annotator.
     *
     * @param chooser
     *            strategy deciding how many annotations to produce per object
     */
    public IndependentPriorRandomAnnotator(NumAnnotationsChooser chooser) {
        this.numAnnotations = chooser;
    }

    /**
     * Count how often each annotation appears across the training data. Build
     * a weighted sampler from those counts, then train the
     * {@link NumAnnotationsChooser} on the same data.
     *
     * @param data
     *            the annotated training samples
     * @throws IllegalArgumentException
     *             if the training data contains no annotations at all (no
     *             distribution can be built)
     */
    @Override
    public void train(List<? extends Annotated<OBJECT, ANNOTATION>> data) {
        final TObjectIntHashMap<ANNOTATION> annotationCounts = new TObjectIntHashMap<ANNOTATION>();
        for (final Annotated<OBJECT, ANNOTATION> sample : data) {
            for (final ANNOTATION annotation : sample.getAnnotations()) {
                annotationCounts.adjustOrPutValue(annotation, 1, 1);
            }
        }

        if (annotationCounts.isEmpty()) {
            throw new IllegalArgumentException("Cannot train: the training data contains no annotations");
        }

        // Flatten the map into two index-aligned sequences: annotation i has weight i.
        final List<ANNOTATION> distinct = new ArrayList<ANNOTATION>(annotationCounts.size());
        final TDoubleArrayList counts = new TDoubleArrayList(annotationCounts.size());
        annotationCounts.forEachEntry(new TObjectIntProcedure<ANNOTATION>() {
            @Override
            public boolean execute(ANNOTATION annotation, int count) {
                distinct.add(annotation);
                counts.add((double) count);
                return true;
            }
        });

        final double[] weights = counts.toArray();
        double total = 0;
        for (final double w : weights) {
            total += w;
        }
        // Priors are kept so annotate() can report a score without relying on the
        // sampler's pdf() indexing conventions.
        final double[] priors = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            priors[i] = weights[i] / total;
        }

        this.annotations = distinct;
        this.annotationPriors = priors;
        // NOTE(review): interpolation type 1 is carried over from the original code.
        this.annotationProbability = new EmpiricalWalker(weights, 1, new MersenneTwister());
        this.numAnnotations.train(data);
    }

    /**
     * Produce a random set of annotations. The given object is not inspected.
     * Each annotation is drawn independently according to the training priors.
     * Each result is scored with the prior of the drawn annotation.
     *
     * @param image
     *            the object to annotate (ignored)
     * @return the drawn annotations, possibly with repeats
     * @throws IllegalStateException
     *             if called before {@link #train(List)}
     */
    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT image) {
        if (this.annotationProbability == null || this.annotations == null) {
            throw new IllegalStateException("The annotator must be trained before annotate() is called");
        }
        final int nAnnotations = this.numAnnotations.numAnnotations();
        final List<ScoredAnnotation<ANNOTATION>> result = new ArrayList<ScoredAnnotation<ANNOTATION>>();
        for (int i = 0; i < nAnnotations; ++i) {
            final int idx = this.annotationProbability.nextInt();
            result.add(new ScoredAnnotation<ANNOTATION>(this.annotations.get(idx), (float) priorOf(idx)));
        }
        return result;
    }

    /**
     * Look up the score for the annotation at {@code idx}.
     *
     * <p>Uses the priors computed in {@link #train(List)}. If a subclass has populated
     * the protected fields directly without training, the priors may be missing.
     * In that case it falls back to the sampler's {@code pdf} with the legacy
     * {@code +1} offset.
     */
    private double priorOf(int idx) {
        if (this.annotationPriors != null && idx >= 0 && idx < this.annotationPriors.length) {
            return this.annotationPriors[idx];
        }
        // NOTE(review): legacy scoring path; the offset appears to push the last
        // index outside EmpiricalWalker.pdf's accepted range. Kept only for
        // subclasses that bypass train().
        return this.annotationProbability.pdf(idx + 1);
    }

    /**
     * Get the distinct annotations learned during training.
     *
     * @return a new mutable set of the known annotations; empty if the
     *         annotator has not been trained
     */
    @Override
    public Set<ANNOTATION> getAnnotations() {
        if (this.annotations == null) {
            return new HashSet<ANNOTATION>();
        }
        return new HashSet<ANNOTATION>(this.annotations);
    }
}

