package edu.stanford.nlp.coref.statistical;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

/* loaded from: input_file:edu/stanford/nlp/coref/statistical/SimpleLinearClassifier.class */
/**
 * A linear model over sparse, string-keyed features, trained one example at a time.
 *
 * <p>The score of an example is the dot product of its feature values with the current weights
 * (see {@link #weightFeatureProduct}). Each call to {@link #learn} takes a gradient step for every
 * feature with a non-zero gradient. The step uses a pluggable {@link Loss} and a per-feature
 * {@link LearningRateSchedule}.
 *
 * <p>Before the step, the feature's weight is shrunk toward zero by an L1-style amount. The
 * amount is proportional to the number of examples seen since that feature was last updated, so
 * regularization is applied lazily, only to the features being touched.
 */
public class SimpleLinearClassifier {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SimpleLinearClassifier.class);

    /** Loss used by {@link #label} and by the three-argument {@link #learn}. */
    private final Loss defaultLoss;
    private final LearningRateSchedule learningRateSchedule;
    /** Scales the per-example shrinkage applied to weights in {@link #learn}. */
    private final double regularizationStrength;
    private final Counter<String> weights;
    /** For each feature, the value of {@link #examplesSeen} when its weight was last updated. */
    private final Counter<String> accessTimes;
    /** Number of calls to {@link #learn} so far. */
    private int examplesSeen;

    /**
     * A learning-rate schedule whose per-feature rate is a function of a per-feature accumulator.
     * The accumulator is incremented on every {@link #update}.
     */
    private abstract static class CountBasedLearningRate implements LearningRateSchedule {
        private final Counter<String> counter = new ClassicCounter<>();

        @Override
        public void update(String feature, double gradient) {
            this.counter.incrementCount(feature, getCounterIncrement(gradient));
        }

        @Override
        public double getLearningRate(String feature) {
            return getLearningRate(this.counter.getCount(feature));
        }

        /** Returns how much to add to a feature's accumulator when it receives {@code gradient}. */
        public abstract double getCounterIncrement(double gradient);

        /** Maps a feature's accumulator value to its learning rate. */
        public abstract double getLearningRate(double count);
    }

    /** Supplies per-feature learning rates and is notified of each per-feature gradient. */
    public interface LearningRateSchedule {
        /**
         * Records a gradient for {@code feature}.
         *
         * <p>{@link SimpleLinearClassifier#learn} calls this only for non-zero gradients, and
         * always before {@link #getLearningRate}.
         */
        void update(String feature, double gradient);

        /** Returns the current learning rate for {@code feature}. */
        double getLearningRate(String feature);
    }

    /** A loss over a real-valued model score. */
    public interface Loss {
        /** Converts a model score into the value returned by {@link SimpleLinearClassifier#label}. */
        double predict(double score);

        /** Returns the derivative of the loss with respect to {@code score} for the given label. */
        double derivative(double label, double score);
    }

    /** Creates a classifier whose weights start empty, so every feature weight reads as zero. */
    public SimpleLinearClassifier(Loss loss, LearningRateSchedule learningRateSchedule,
                                  double regularizationStrength) {
        this(loss, learningRateSchedule, regularizationStrength, null);
    }

    /**
     * Creates a classifier, optionally initializing its weights from a file.
     *
     * @param loss default loss for {@link #learn(Counter, double, double)} and {@link #label}
     * @param learningRateSchedule per-feature learning-rate schedule
     * @param regularizationStrength strength of the lazy shrinkage applied in {@link #learn}
     * @param modelFile where to read the weights from, or {@code null} to start with empty
     *     weights. A path ending in {@code .tab.gz} is read with
     *     {@link Counters#deserializeStringCounter}; any other path is read as a serialized object
     *     through {@link IOUtils}
     * @throws RuntimeException wrapping any exception raised while reading {@code modelFile}
     */
    public SimpleLinearClassifier(Loss loss, LearningRateSchedule learningRateSchedule,
                                  double regularizationStrength, String modelFile) {
        if (modelFile != null) {
            this.weights = loadWeights(modelFile);
        } else {
            this.weights = new ClassicCounter<>();
        }
        this.defaultLoss = loss;
        this.regularizationStrength = regularizationStrength;
        this.learningRateSchedule = learningRateSchedule;
        this.accessTimes = new ClassicCounter<>();
        this.examplesSeen = 0;
    }

    /** Reads a weight counter from {@code path}; see the four-argument constructor for formats. */
    private static Counter<String> loadWeights(String path) {
        try {
            if (path.endsWith(".tab.gz")) {
                Timing.startDoing("Reading " + path);
                Counter<String> loaded = Counters.deserializeStringCounter(path);
                Timing.endDoing("Reading " + path);
                return loaded;
            }
            // The serialized object is assumed to be a Counter<String>; the cast cannot be checked.
            @SuppressWarnings("unchecked")
            Counter<String> loaded = (Counter<String>)
                IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(log, "Loading coref model", path);
            return loaded;
        } catch (Exception e) {
            throw new RuntimeException("Error loading weights from " + path, e);
        }
    }

    /**
     * Trains on one example using the default loss.
     *
     * <p>See {@link #learn(Counter, double, double, Loss)} for details.
     */
    public void learn(Counter<String> features, double label, double weight) {
        learn(features, label, weight, this.defaultLoss);
    }

    /**
     * Performs one stochastic update on a single example.
     *
     * <p>The gradient for each feature is {@code weight * -loss'(label, score) * value}. Features
     * whose gradient is zero are skipped. For every other feature:
     * <ol>
     *   <li>the learning-rate schedule is updated;</li>
     *   <li>the feature's weight is shrunk toward zero by
     *       {@code weight * regularizationStrength * (examples since last update) * learningRate},
     *       and clipped to exactly zero if the shrinkage would flip its sign;</li>
     *   <li>the gradient times the learning rate is added.</li>
     * </ol>
     *
     * @param features feature values of the example
     * @param label label passed to {@link Loss#derivative}
     * @param weight per-example multiplier for both the gradient and the shrinkage
     * @param loss loss to differentiate
     */
    public void learn(Counter<String> features, double label, double weight, Loss loss) {
        this.examplesSeen++;
        // The score is computed once, before any weight of this example changes.
        double dLoss = loss.derivative(label, weightFeatureProduct(features));
        for (Map.Entry<String, Double> entry : features.entrySet()) {
            double step = weight * (-dLoss) * entry.getValue().doubleValue();
            if (step != 0.0d) {
                String feature = entry.getKey();
                this.learningRateSchedule.update(feature, step);
                double learningRate = this.learningRateSchedule.getLearningRate(feature);
                double current = this.weights.getCount(feature);
                // Lazy regularization: apply the shrinkage owed for every example seen since
                // this feature was last updated.
                double staleness = this.examplesSeen - this.accessTimes.getCount(feature);
                double shrunk = current
                    - Math.signum(current) * ((weight * this.regularizationStrength) * staleness) * learningRate;
                // Shrinkage may pull a weight to zero but never past it.
                double regularized = Math.signum(shrunk) != Math.signum(current) ? 0.0d : shrunk;
                this.weights.setCount(feature, regularized + step * learningRate);
                this.accessTimes.setCount(feature, this.examplesSeen);
            }
        }
    }

    /**
     * Scores {@code features} and maps the score through the default loss's {@link Loss#predict}.
     *
     * <p>For a {@link #maxMargin} loss, {@code predict} throws {@link UnsupportedOperationException}.
     */
    public double label(Counter<String> features) {
        return this.defaultLoss.predict(weightFeatureProduct(features));
    }

    /** Returns the sum over {@code features} of each value times the current weight for its key. */
    public double weightFeatureProduct(Counter<String> features) {
        double score = 0.0d;
        for (Map.Entry<String, Double> entry : features.entrySet()) {
            score += entry.getValue().doubleValue() * this.weights.getCount(entry.getKey());
        }
        return score;
    }

    /** Overwrites the weight of {@code feature}. */
    public void setWeight(String feature, double value) {
        this.weights.setCount(feature, value);
    }

    /**
     * Returns a copy of the weights, sorted by descending absolute weight and then by feature name.
     *
     * <p>The ordering is computed from a snapshot, so later training or {@link #setWeight} calls
     * do not corrupt the returned map. {@link Double#compare} is used so that NaN weights are
     * ordered consistently rather than colliding with other keys and being dropped.
     */
    public SortedMap<String, Double> getWeightVector() {
        final Map<String, Double> snapshot = new HashMap<>();
        for (Map.Entry<String, Double> entry : this.weights.entrySet()) {
            snapshot.put(entry.getKey(), entry.getValue());
        }
        // Keys absent from the snapshot, e.g. in get() lookups, compare as the counter's default value.
        final double absent = this.weights.defaultReturnValue();
        SortedMap<String, Double> sorted = new TreeMap<>((a, b) -> {
            int byMagnitude = Double.compare(
                Math.abs(snapshot.getOrDefault(b, absent)),
                Math.abs(snapshot.getOrDefault(a, absent)));
            return byMagnitude != 0 ? byMagnitude : a.compareTo(b);
        });
        sorted.putAll(snapshot);
        return sorted;
    }

    /** Logs every weight, largest magnitude first, to the {@code scoref.train} channel. */
    public void printWeightVector() {
        printWeightVector(null);
    }

    /**
     * Prints every weight as {@code feature => value}, largest magnitude first.
     *
     * @param printWriter destination, or {@code null} to log to the {@code scoref.train} channel
     */
    public void printWeightVector(PrintWriter printWriter) {
        for (Map.Entry<String, Double> entry : getWeightVector().entrySet()) {
            String line = entry.getKey() + " => " + entry.getValue();
            if (printWriter == null) {
                Redwood.log("scoref.train", line);
            } else {
                printWriter.println(line);
            }
        }
    }

    /** Serializes the weight counter to {@code path} with {@link IOUtils#writeObjectToFile}. */
    public void writeWeights(String path) throws Exception {
        IOUtils.writeObjectToFile(this.weights, path);
    }

    /**
     * Returns the logistic loss.
     *
     * <p>{@code predict} is the sigmoid of the score. {@code derivative} is
     * {@code -label / (1 + exp(label * score))}, the derivative of
     * {@code log(1 + exp(-label * score))}. Labels are presumably -1/+1; confirm with callers.
     */
    public static Loss log() {
        return new Loss() {
            @Override
            public double predict(double score) {
                return 1.0d - (1.0d / (1.0d + Math.exp(score)));
            }

            @Override
            public double derivative(double label, double score) {
                return (-label) / (1.0d + Math.exp(label * score));
            }

            @Override
            public String toString() {
                return "log";
            }
        };
    }

    /**
     * Returns a quadratically smoothed hinge loss.
     *
     * <p>{@code predict} returns the raw score. With {@code z = label * score}, the derivative is:
     * <ul>
     *   <li>0 when {@code z >= 1};</li>
     *   <li>{@code (z - 1) * label / gamma} when {@code 1 - gamma <= z < 1};</li>
     *   <li>{@code -label} otherwise.</li>
     * </ul>
     */
    public static Loss quadraticallySmoothedSVM(final double gamma) {
        return new Loss() {
            @Override
            public double predict(double score) {
                return score;
            }

            @Override
            public double derivative(double label, double score) {
                double z = label * score;
                if (z >= 1.0d) {
                    return 0.0d;
                }
                // With gamma == 0 this branch is unreachable (z >= 1 returned above), so no 0/0.
                return z >= 1.0d - gamma ? ((z - 1.0d) * label) / gamma : -label;
            }

            @Override
            public String toString() {
                return String.format("quadraticallySmoothed(%s)", Double.valueOf(gamma));
            }
        };
    }

    /** Returns the plain hinge loss, i.e. {@link #quadraticallySmoothedSVM} with zero smoothing. */
    public static Loss hinge() {
        return quadraticallySmoothedSVM(0.0d);
    }

    /**
     * Returns a max-margin loss whose derivative is 0 when {@code score < -margin} and 1 otherwise.
     *
     * <p>The label is ignored. {@code predict} is not supported.
     */
    public static Loss maxMargin(final double margin) {
        return new Loss() {
            @Override
            public double predict(double score) {
                throw new UnsupportedOperationException("Predict not implemented for max margin");
            }

            @Override
            public double derivative(double label, double score) {
                return score < (-margin) ? 0.0d : 1.0d;
            }

            @Override
            public String toString() {
                return String.format("max-margin(%s)", Double.valueOf(margin));
            }
        };
    }

    /**
     * Returns the "risk" loss {@code 1 / (1 + exp(score))}, whose derivative ignores the label.
     */
    public static Loss risk() {
        return new Loss() {
            @Override
            public double predict(double score) {
                return 1.0d / (1.0d + Math.exp(score));
            }

            @Override
            public double derivative(double label, double score) {
                // d/ds = -e^s / (1 + e^s)^2. This is symmetric in s, so evaluate it at -|s|.
                // That keeps Math.exp from overflowing to Infinity, which made the direct form
                // Infinity / Infinity = NaN for scores above ~709.
                double e = Math.exp(-Math.abs(score));
                return -e / Math.pow(1.0d + e, 2.0d);
            }

            @Override
            public String toString() {
                return "risk";
            }
        };
    }

    /** Returns a schedule that always yields {@code learningRate} and ignores updates. */
    public static LearningRateSchedule constant(final double learningRate) {
        return new LearningRateSchedule() {
            @Override
            public double getLearningRate(String feature) {
                return learningRate;
            }

            @Override
            public void update(String feature, double gradient) {
                // Nothing to track: the rate never changes.
            }

            @Override
            public String toString() {
                return String.format("constant(%s)", Double.valueOf(learningRate));
            }
        };
    }

    /**
     * Returns a schedule that yields {@code eta / (1 + n)^power}.
     *
     * <p>Here {@code n} is the number of updates the feature has received.
     */
    public static LearningRateSchedule invScaling(final double eta, final double power) {
        return new CountBasedLearningRate() {
            @Override
            public double getCounterIncrement(double gradient) {
                return 1.0d;
            }

            @Override
            public double getLearningRate(double count) {
                return eta / Math.pow(1.0d + count, power);
            }

            @Override
            public String toString() {
                return String.format("invScaling(%s, %s)", Double.valueOf(eta), Double.valueOf(power));
            }
        };
    }

    /**
     * Returns an AdaGrad-style schedule that yields {@code eta / (epsilon + sqrt(G))}.
     *
     * <p>Here {@code G} is the sum of the feature's squared gradients.
     */
    public static LearningRateSchedule adaGrad(final double eta, final double epsilon) {
        return new CountBasedLearningRate() {
            @Override
            public double getCounterIncrement(double gradient) {
                return gradient * gradient;
            }

            @Override
            public double getLearningRate(double count) {
                return eta / (epsilon + Math.sqrt(count));
            }

            @Override
            public String toString() {
                return String.format("adaGrad(%s, %s)", Double.valueOf(eta), Double.valueOf(epsilon));
            }
        };
    }
}
