/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.stats;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.util.Generics;
import java.util.Set;

/**
 * Static helpers for comparing and combining two {@link Distribution}s over the same key space:
 * overlap, weighted averaging, and several KL-based divergences (all reported in bits).
 *
 * <p>Every two-argument operation first validates its inputs with {@link #getSetOfAllKeys}.
 */
public class Distributions {
    private Distributions() {
    }

    /**
     * Returns the union of the keys explicitly stored in the counters of {@code d1} and
     * {@code d2}, after checking that both distributions declare the same number of keys.
     *
     * @param d1 first distribution
     * @param d2 second distribution
     * @return a new set containing every key present in either counter
     * @throws IllegalArgumentException if the declared numbers of keys differ, or if the union of
     *     the stored keys is larger than that number
     */
    protected static <K> Set<K> getSetOfAllKeys(Distribution<K> d1, Distribution<K> d2) {
        int numKeys = d1.getNumberOfKeys();
        if (numKeys != d2.getNumberOfKeys()) {
            throw new IllegalArgumentException("Tried to compare two Distribution<K> objects but d1.numberOfKeys != d2.numberOfKeys (" + numKeys + " != " + d2.getNumberOfKeys() + ")");
        }
        Set<K> allKeys = Generics.newHashSet(d1.getCounter().keySet());
        allKeys.addAll(d2.getCounter().keySet());
        if (allKeys.size() > numKeys) {
            throw new IllegalArgumentException("Tried to compare two Distribution<K> objects but size of (d1.counter union d2.counter) > numberOfKeys (" + allKeys.size() + " > " + numKeys + ")");
        }
        return allKeys;
    }

    /**
     * Computes the overlap {@code sum_k min(d1(k), d2(k))} of two distributions.
     *
     * <p>Keys stored in either counter are summed explicitly. All remaining keys are accounted
     * for in one step as the smaller of the two leftover masses. That step equals the per-key
     * sum only if both distributions spread their leftover mass in the same proportions over
     * those keys; this is presumably uniform in {@link Distribution#probabilityOf}, so confirm
     * there.
     *
     * @param d1 first distribution
     * @param d2 second distribution
     * @return the overlap, never negative
     * @throws IllegalArgumentException if the distributions are not comparable (see
     *     {@link #getSetOfAllKeys})
     */
    public static <K> double overlap(Distribution<K> d1, Distribution<K> d2) {
        Set<K> allKeys = Distributions.getSetOfAllKeys(d1, d2);
        double result = 0.0;
        double remainingMass1 = 1.0;
        double remainingMass2 = 1.0;
        for (K key : allKeys) {
            double p1 = d1.probabilityOf(key);
            double p2 = d2.probabilityOf(key);
            remainingMass1 -= p1;
            remainingMass2 -= p2;
            result += Math.min(p1, p2);
        }
        // The leftover masses can round to slightly below zero when nothing is left to assign;
        // clamp so the unlisted keys never reduce the overlap.
        result += Math.max(0.0, Math.min(remainingMass1, remainingMass2));
        return result;
    }

    /**
     * Returns the mixture {@code w1 * d1 + (1 - w1) * d2}.
     *
     * <p>Mixed probabilities are stored for every key held in either counter. The result is
     * built by {@link Distribution#getDistributionFromPartiallySpecifiedCounter} with
     * {@code d1}'s number of keys; that factory presumably assigns the leftover mass to the
     * unlisted keys. {@code w1} is not range-checked here.
     *
     * @param d1 first distribution
     * @param w1 weight of {@code d1}; {@code d2} receives {@code 1 - w1}
     * @param d2 second distribution
     * @return the weighted average distribution
     * @throws IllegalArgumentException if the distributions are not comparable (see
     *     {@link #getSetOfAllKeys})
     */
    public static <K> Distribution<K> weightedAverage(Distribution<K> d1, double w1, Distribution<K> d2) {
        double w2 = 1.0 - w1;
        Set<K> allKeys = Distributions.getSetOfAllKeys(d1, d2);
        int numKeys = d1.getNumberOfKeys();
        ClassicCounter<K> c = new ClassicCounter<K>();
        for (K key : allKeys) {
            double newProbability = d1.probabilityOf(key) * w1 + d2.probabilityOf(key) * w2;
            c.setCount(key, newProbability);
        }
        return Distribution.getDistributionFromPartiallySpecifiedCounter(c, numKeys);
    }

    /**
     * Returns the equal-weight mixture {@code 0.5 * d1 + 0.5 * d2}.
     *
     * @param d1 first distribution
     * @param d2 second distribution
     * @return the average distribution
     */
    public static <K> Distribution<K> average(Distribution<K> d1, Distribution<K> d2) {
        return Distributions.weightedAverage(d1, 0.5, d2);
    }

    /**
     * Computes the Kullback-Leibler divergence {@code KL(from || to)} in bits (log base 2).
     *
     * <p>Keys whose {@code from}-probability is below {@code 1e-10} contribute nothing. Keys not
     * stored in either counter are handled together: the leftover mass of each distribution is
     * split evenly over the remaining key count, and that per-key term is multiplied back up.
     * If any contributing term has zero (or rounded-negative) {@code to}-probability, a
     * diagnostic line is printed to {@code System.out} and {@code +inf} is returned.
     *
     * @param from the reference distribution P
     * @param to the approximating distribution Q
     * @return {@code KL(P || Q)} in bits, possibly {@code Double.POSITIVE_INFINITY}
     * @throws IllegalArgumentException if the distributions are not comparable (see
     *     {@link #getSetOfAllKeys})
     */
    public static <K> double klDivergence(Distribution<K> from, Distribution<K> to) {
        Set<K> allKeys = Distributions.getSetOfAllKeys(from, to);
        int numKeysRemaining = from.getNumberOfKeys();
        double result = 0.0;
        double assignedMass1 = 0.0;
        double assignedMass2 = 0.0;
        double log2 = Math.log(2.0);
        double epsilon = 1.0E-10;
        for (K key : allKeys) {
            double p1 = from.probabilityOf(key);
            double p2 = to.probabilityOf(key);
            --numKeysRemaining;
            assignedMass1 += p1;
            assignedMass2 += p2;
            if (p1 < epsilon) {
                // Treat p * log(p / q) as 0 for negligible p.
                continue;
            }
            double logFract = Distributions.logRatio(p1, p2);
            if (logFract == Double.POSITIVE_INFINITY) {
                System.out.println("Distributions.klDivergence returning +inf: p1=" + p1 + ", p2=" + p2);
                System.out.flush();
                return Double.POSITIVE_INFINITY;
            }
            result += p1 * (logFract / log2);
        }
        if (numKeysRemaining != 0) {
            // Per-key probability of each key not stored in either counter.
            double p1 = (1.0 - assignedMass1) / (double) numKeysRemaining;
            if (p1 > epsilon) {
                double p2 = (1.0 - assignedMass2) / (double) numKeysRemaining;
                double logFract = Distributions.logRatio(p1, p2);
                if (logFract == Double.POSITIVE_INFINITY) {
                    System.out.println("Distributions.klDivergence (remaining mass) returning +inf: p1=" + p1 + ", p2=" + p2);
                    System.out.flush();
                    return Double.POSITIVE_INFINITY;
                }
                result += (double) numKeysRemaining * p1 * (logFract / log2);
            }
        }
        return result;
    }

    /**
     * Returns the natural log of {@code p1 / p2} for a positive {@code p1}. A non-positive
     * {@code p2} is treated as zero mass, giving {@code +inf}.
     */
    private static double logRatio(double p1, double p2) {
        // 1.0 - assignedMass can round to a tiny negative value. Math.log of a negative ratio is
        // NaN, which would silently poison the sum instead of signalling infinite divergence.
        if (p2 <= 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        return Math.log(p1 / p2);
    }

    /**
     * Computes the Jensen-Shannon divergence in bits. It is the mean of
     * {@code KL(d1 || m)} and {@code KL(d2 || m)}, where {@code m} is {@link #average}.
     *
     * @param d1 first distribution
     * @param d2 second distribution
     * @return the Jensen-Shannon divergence
     */
    public static <K> double jensenShannonDivergence(Distribution<K> d1, Distribution<K> d2) {
        Distribution<K> average = Distributions.average(d1, d2);
        double kl1 = Distributions.klDivergence(d1, average);
        double kl2 = Distributions.klDivergence(d2, average);
        double js = (kl1 + kl2) / 2.0;
        return js;
    }

    /**
     * Computes the skew divergence {@code KL(d1 || skew * d2 + (1 - skew) * d1)} in bits.
     *
     * @param d1 the reference distribution
     * @param d2 the distribution mixed into the second argument of KL
     * @param skew weight of {@code d2} in the mixture
     * @return the skew divergence
     */
    public static <K> double skewDivergence(Distribution<K> d1, Distribution<K> d2, double skew) {
        Distribution<K> average = Distributions.weightedAverage(d2, skew, d1);
        return Distributions.klDivergence(d1, average);
    }

    /**
     * Computes the information radius {@code KL(d1 || m) + KL(d2 || m)} in bits, where
     * {@code m} is {@link #average}. This is twice {@link #jensenShannonDivergence}.
     *
     * @param d1 first distribution
     * @param d2 second distribution
     * @return the information radius
     */
    public static <K> double informationRadius(Distribution<K> d1, Distribution<K> d2) {
        Distribution<K> avg = Distributions.average(d1, d2);
        return Distributions.klDivergence(d1, avg) + Distributions.klDivergence(d2, avg);
    }
}

