/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.neural.rnn.TopNGramRecord;
import edu.stanford.nlp.sentiment.RNNOptions;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * Base class for sentiment-tree evaluators.
 *
 * <p>Accumulates, over every evaluated tree:
 * <ul>
 *   <li>per-node label accuracy and confusion;</li>
 *   <li>per-root label accuracy and confusion;</li>
 *   <li>accuracy bucketed by span length;</li>
 *   <li>optionally, an n-gram record.</li>
 * </ul>
 * It can also log a summary of these statistics.
 *
 * <p>Subclasses implement {@link #populatePredictedLabels(List)}. {@link #eval(List)} calls that
 * method once, before counting any tree. Instances are mutable and perform no synchronization.
 */
public abstract class AbstractEvaluate {
    private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractEvaluate.class);

    /** Display names for the equivalence classes, copied from the options in {@link #reset()}; may be null. */
    String[] equivalenceClassNames;
    /** Number of non-leaf nodes (with gold >= 0) whose predicted label matched the gold label. */
    int labelsCorrect;
    /** Number of non-leaf nodes (with gold >= 0) whose predicted label did not match the gold label. */
    int labelsIncorrect;
    /** Node-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] labelConfusion;
    /** Number of evaluated roots (with gold >= 0) whose predicted label matched the gold label. */
    int rootLabelsCorrect;
    /** Number of evaluated roots (with gold >= 0) whose predicted label did not match the gold label. */
    int rootLabelsIncorrect;
    /** Root-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] rootLabelConfusion;
    /** Correct-prediction counts keyed by span length (number of preterminals under a node). */
    IntCounter<Integer> lengthLabelsCorrect;
    /** Incorrect-prediction counts keyed by span length (number of preterminals under a node). */
    IntCounter<Integer> lengthLabelsIncorrect;
    /** N-gram record; null when {@code testOptions.ngramRecordSize <= 0} at the last {@link #reset()}. */
    TopNGramRecord ngrams;
    static final int NUM_NGRAMS = 5;
    /** Groups of label indices scored together by the "approximate" accuracies; may be null. */
    int[][] equivalenceClasses;
    /**
     * Shared formatter for printed accuracies.
     * NOTE: {@link DecimalFormat} is not synchronized, so concurrent printing from several threads
     * through this shared instance is unsafe.
     */
    protected static final NumberFormat NF = new DecimalFormat("0.000000");
    private final RNNOptions op;

    /**
     * Creates an evaluator configured by {@code options} and initializes all counters.
     *
     * @param options source of the class count, equivalence classes and test options
     */
    public AbstractEvaluate(RNNOptions options) {
        this.op = options;
        // reset() is overridable; an override runs before subclass fields are initialized.
        this.reset();
    }

    /**
     * Logs a confusion matrix using {@link ConfusionMatrix}.
     *
     * @param name      label printed before the matrix
     * @param confusion counts indexed {@code [gold][predicted]}, as filled by
     *                  {@link #countTree(Tree)} and {@link #countRoot(Tree)}
     */
    protected static void printConfusionMatrix(String name, int[][] confusion) {
        log.info(name + " confusion matrix");
        ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<Integer>();
        confusionMatrix.setUseRealLabels(true);
        for (int gold = 0; gold < confusion.length; ++gold) {
            for (int predicted = 0; predicted < confusion[gold].length; ++predicted) {
                // NOTE(review): assumes ConfusionMatrix.add takes (guess, gold, count); confirm.
                confusionMatrix.add(predicted, gold, confusion[gold][predicted]);
            }
        }
        log.info(confusionMatrix);
    }

    /**
     * Computes per-equivalence-class accuracy.
     *
     * <p>For class {@code i}, "correct" counts items whose gold and predicted labels both lie in
     * {@code classes[i]}. "Total" counts all items whose gold label lies in {@code classes[i]}.
     *
     * @param confusion counts indexed {@code [gold][predicted]}
     * @param classes   groups of label indices
     * @return one accuracy per entry of {@code classes}; NaN for a class with no gold items
     */
    protected static double[] approxAccuracy(int[][] confusion, int[][] classes) {
        double[] results = new double[classes.length];
        for (int i = 0; i < classes.length; ++i) {
            long[] counts = classCounts(confusion, classes[i]);
            results[i] = (double) counts[0] / (double) counts[1];
        }
        return results;
    }

    /**
     * Computes accuracy pooled over all equivalence classes.
     *
     * <p>Sums the "correct" and "total" counts of {@link #approxAccuracy} across every class, then
     * divides.
     *
     * @param confusion counts indexed {@code [gold][predicted]}
     * @param classes   groups of label indices
     * @return pooled accuracy; NaN when no gold item falls in any class
     */
    protected static double approxCombinedAccuracy(int[][] confusion, int[][] classes) {
        long correct = 0;
        long total = 0;
        for (int[] labels : classes) {
            long[] counts = classCounts(confusion, labels);
            correct += counts[0];
            total += counts[1];
        }
        return (double) correct / (double) total;
    }

    /**
     * Counts agreement for a single equivalence class.
     *
     * <p>Counts are accumulated in {@code long} so that large evaluations cannot overflow.
     *
     * @param confusion counts indexed {@code [gold][predicted]}
     * @param labels    label indices forming one equivalence class
     * @return {@code {correct, total}}, where "correct" counts items with gold and predicted both in
     *         {@code labels}, and "total" counts all items with gold in {@code labels}
     */
    private static long[] classCounts(int[][] confusion, int[] labels) {
        long correct = 0;
        long total = 0;
        for (int gold : labels) {
            for (int predicted : labels) {
                correct += confusion[gold][predicted];
            }
            for (int count : confusion[gold]) {
                total += count;
            }
        }
        return new long[] {correct, total};
    }

    /**
     * Clears every counter and re-reads the relevant settings from the options.
     *
     * <p>Specifically, re-reads the class count, the equivalence classes and the n-gram settings.
     * An n-gram record is created only when {@code testOptions.ngramRecordSize > 0}.
     */
    public void reset() {
        this.labelsCorrect = 0;
        this.labelsIncorrect = 0;
        this.labelConfusion = new int[this.op.numClasses][this.op.numClasses];
        this.rootLabelsCorrect = 0;
        this.rootLabelsIncorrect = 0;
        this.rootLabelConfusion = new int[this.op.numClasses][this.op.numClasses];
        this.lengthLabelsCorrect = new IntCounter<Integer>();
        this.lengthLabelsIncorrect = new IntCounter<Integer>();
        this.equivalenceClasses = this.op.equivalenceClasses;
        this.equivalenceClassNames = this.op.equivalenceClassNames;
        this.ngrams = this.op.testOptions.ngramRecordSize > 0
            ? new TopNGramRecord(this.op.numClasses, this.op.testOptions.ngramRecordSize, this.op.testOptions.ngramRecordMaximumLength)
            : null;
    }

    /**
     * Evaluates a batch of trees.
     *
     * <p>First calls {@link #populatePredictedLabels(List)} on the whole list, then counts each tree
     * with {@link #eval(Tree)}.
     *
     * @param trees trees carrying gold-class annotations
     */
    public void eval(List<Tree> trees) {
        this.populatePredictedLabels(trees);
        for (Tree tree : trees) {
            this.eval(tree);
        }
    }

    /**
     * Adds one tree to the running statistics.
     *
     * <p>Updates node counts, root counts and length counts, and the n-gram record when enabled.
     * This method does not call {@link #populatePredictedLabels(List)}; the tree must already carry
     * predicted-class annotations.
     *
     * @param tree an annotated tree
     */
    public void eval(Tree tree) {
        this.countTree(tree);
        this.countRoot(tree);
        this.countLengthAccuracy(tree);
        if (this.ngrams != null) {
            this.ngrams.countTree(tree);
        }
    }

    /**
     * Records correctness of every non-leaf node under {@code tree}, bucketed by span length.
     *
     * <p>Nodes whose gold class is negative are not counted. NOTE(review): presumably a negative
     * gold class means "unlabeled"; confirm.
     *
     * @param tree subtree to process
     * @return span length of {@code tree}: 0 for a leaf, 1 for a preterminal, otherwise the sum of
     *         the children's lengths (i.e. the number of preterminals below)
     */
    protected int countLengthAccuracy(Tree tree) {
        if (tree.isLeaf()) {
            return 0;
        }
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        int length;
        if (tree.isPreTerminal()) {
            length = 1;
        } else {
            length = 0;
            for (Tree child : tree.children()) {
                length += this.countLengthAccuracy(child);
            }
        }
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                this.lengthLabelsCorrect.incrementCount(length);
            } else {
                this.lengthLabelsIncorrect.incrementCount(length);
            }
        }
        return length;
    }

    /**
     * Updates node-level counts for every non-leaf node in {@code tree}.
     *
     * <p>Children are processed before their parent (post-order). Both the correct/incorrect
     * counters and {@link #labelConfusion} are updated. Nodes with a negative gold class are skipped.
     *
     * @param tree subtree to count
     */
    protected void countTree(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        for (Tree child : tree.children()) {
            this.countTree(child);
        }
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                ++this.labelsCorrect;
            } else {
                ++this.labelsIncorrect;
            }
            this.labelConfusion[gold][predicted]++;
        }
    }

    /**
     * Updates root-level counts for {@code tree} itself; children are not visited.
     *
     * <p>Both the correct/incorrect counters and {@link #rootLabelConfusion} are updated. The node
     * is skipped when its gold class is negative.
     *
     * @param tree the root node to count
     */
    protected void countRoot(Tree tree) {
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                ++this.rootLabelsCorrect;
            } else {
                ++this.rootLabelsIncorrect;
            }
            this.rootLabelConfusion[gold][predicted]++;
        }
    }

    /** @return fraction of counted nodes labeled correctly; NaN if no node has been counted */
    public double exactNodeAccuracy() {
        return (double) this.labelsCorrect / (double) (this.labelsCorrect + this.labelsIncorrect);
    }

    /** @return fraction of counted roots labeled correctly; NaN if no root has been counted */
    public double exactRootAccuracy() {
        return (double) this.rootLabelsCorrect / (double) (this.rootLabelsCorrect + this.rootLabelsIncorrect);
    }

    /**
     * Computes per-length accuracy.
     *
     * @return a fresh counter mapping each observed span length to
     *         {@code correct / (correct + incorrect)} for that length
     */
    public Counter<Integer> lengthAccuracies() {
        Set<Integer> keys = Generics.newHashSet();
        keys.addAll(this.lengthLabelsCorrect.keySet());
        keys.addAll(this.lengthLabelsIncorrect.keySet());
        ClassicCounter<Integer> results = new ClassicCounter<Integer>();
        for (Integer key : keys) {
            double correct = this.lengthLabelsCorrect.getCount(key);
            double incorrect = this.lengthLabelsIncorrect.getCount(key);
            results.setCount(key, correct / (correct + incorrect));
        }
        return results;
    }

    /** Logs {@link #lengthAccuracies()} in ascending order of length. */
    public void printLengthAccuracies() {
        Counter<Integer> accuracies = this.lengthAccuracies();
        TreeSet<Integer> keys = Generics.newTreeSet();
        keys.addAll(accuracies.keySet());
        log.info("Label accuracy at various lengths:");
        for (Integer key : keys) {
            log.info(StringUtils.padLeft(Integer.toString(key), 4) + ": " + NF.format(accuracies.getCount(key)));
        }
    }

    /**
     * Logs a full summary of the accumulated statistics.
     *
     * <p>The summary always contains:
     * <ul>
     *   <li>node and root counts and accuracies;</li>
     *   <li>both confusion matrices.</li>
     * </ul>
     * When equivalence classes and their names are configured, it adds approximate per-class and
     * combined accuracies. When enabled in the options, it also adds the n-gram record and the
     * per-length accuracies.
     */
    public void printSummary() {
        log.info("EVALUATION SUMMARY");
        log.info("Tested " + (this.labelsCorrect + this.labelsIncorrect) + " labels");
        log.info("  " + this.labelsCorrect + " correct");
        log.info("  " + this.labelsIncorrect + " incorrect");
        log.info("  " + NF.format(this.exactNodeAccuracy()) + " accuracy");
        log.info("Tested " + (this.rootLabelsCorrect + this.rootLabelsIncorrect) + " roots");
        log.info("  " + this.rootLabelsCorrect + " correct");
        log.info("  " + this.rootLabelsIncorrect + " incorrect");
        log.info("  " + NF.format(this.exactRootAccuracy()) + " accuracy");
        printConfusionMatrix("Label", this.labelConfusion);
        printConfusionMatrix("Root label", this.rootLabelConfusion);
        if (this.equivalenceClasses != null && this.equivalenceClassNames != null) {
            this.printApproximateAccuracies("label", this.labelConfusion);
            this.printApproximateAccuracies("root label", this.rootLabelConfusion);
            log.info();
        }
        // Also require a record to exist, matching eval(Tree): the options may have changed since reset().
        if (this.op.testOptions.ngramRecordSize > 0 && this.ngrams != null) {
            log.info(this.ngrams);
        }
        if (this.op.testOptions.printLengthAccuracies) {
            this.printLengthAccuracies();
        }
    }

    /**
     * Logs per-equivalence-class accuracies, then the combined accuracy, for one confusion matrix.
     *
     * @param what      noun phrase used in the messages ("label" or "root label")
     * @param confusion counts indexed {@code [gold][predicted]}
     */
    private void printApproximateAccuracies(String what, int[][] confusion) {
        double[] perClass = approxAccuracy(confusion, this.equivalenceClasses);
        for (int i = 0; i < this.equivalenceClassNames.length; ++i) {
            log.info("Approximate " + this.equivalenceClassNames[i] + " " + what + " accuracy: " + NF.format(perClass[i]));
        }
        log.info("Combined approximate " + what + " accuracy: " + NF.format(approxCombinedAccuracy(confusion, this.equivalenceClasses)));
    }

    /**
     * Prepares predictions for a batch of trees.
     *
     * <p>{@link #eval(List)} calls this once, before any tree is counted. NOTE(review): presumably
     * implementations attach a predicted-class annotation to every node of each tree; confirm
     * against subclasses.
     *
     * @param var1 the trees about to be evaluated
     */
    public abstract void populatePredictedLabels(List<Tree> var1);
}

