/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.FloatFactorTable;
import edu.stanford.nlp.ie.crf.HasCliquePotentialFunction;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFloatFunction;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Arrays;
import java.util.List;

/**
 * Float-precision CRF training objective: the negative conditional log-likelihood of the gold
 * label sequences plus an optional prior (regularization) penalty, together with its gradient.
 *
 * <p>Data layout, as indexed throughout this class: {@code data[doc][position][clique][n]} is a
 * feature index, where clique {@code j} spans {@code j + 1} consecutive labels and its label
 * assignments are enumerated by {@code labelIndices.get(j)}; {@code labels[doc][position]} is the
 * gold class index. Feature {@code f} owns one weight per entry of
 * {@code labelIndices.get(map[f])}; {@link #to2D} and {@link #to1D} convert between the flat
 * weight vector used by the optimizer and that per-feature layout.
 */
public class CRFLogConditionalObjectiveFloatFunction
extends AbstractCachingDiffFloatFunction
implements HasCliquePotentialFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFloatFunction.class);

    public static final int NO_PRIOR = 0;
    public static final int QUADRATIC_PRIOR = 1;
    public static final int HUBER_PRIOR = 2;
    public static final int QUARTIC_PRIOR = 3;

    /** One of the {@code *_PRIOR} constants; any other value applies no prior. */
    protected int prior;
    /** Scale of the prior; penalties are divided by powers of sigma, so larger means weaker. */
    protected float sigma;
    /**
     * Width of the quadratic region of the Huber prior. It is not assigned in this class, so it
     * stays 0 (making the Huber prior an L1-style penalty) unless something else sets it.
     */
    protected float epsilon;
    /** Label-assignment index for each clique size; entry j enumerates labelings of j+1 labels. */
    List<Index<CRFLabel>> labelIndices;
    Index<String> classIndex;
    /** Empirical (gold) feature/label counts, laid out like {@link #empty2D()}. */
    float[][] Ehat;
    /** Number of consecutive labels in the largest clique. */
    int window;
    int numClasses;
    /** map[f] selects the entry of labelIndices whose labelings feature f has weights for. */
    int[] map;
    int[][][][] data;
    int[][] labels;
    /** Lazily computed by {@link #domainDimension()}; negative means "not yet computed". */
    int domainDimension = -1;
    /** Class used to pad label histories before the start of a document. */
    String backgroundSymbol;
    public static boolean VERBOSE = false;

    CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol) {
        this(data, labels, window, classIndex, labelIndices, map, QUADRATIC_PRIOR, backgroundSymbol);
    }

    CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol, double sigma) {
        this(data, labels, window, classIndex, labelIndices, map, QUADRATIC_PRIOR, backgroundSymbol, sigma);
    }

    CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, int prior, String backgroundSymbol) {
        this(data, labels, window, classIndex, labelIndices, map, prior, backgroundSymbol, 1.0);
    }

    /**
     * Stores the training data and settings and precomputes the empirical feature counts.
     *
     * @param data features, indexed {@code [doc][position][clique][n]}
     * @param labels gold class indices, indexed {@code [doc][position]}
     * @param window number of labels in the largest clique
     * @param classIndex index of class names
     * @param labelIndices labelings for each clique size
     * @param map per-feature selector into {@code labelIndices}
     * @param prior one of the {@code *_PRIOR} constants
     * @param backgroundSymbol class name used to pad label histories
     * @param sigma prior scale (stored as a float)
     */
    CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, int prior, String backgroundSymbol, double sigma) {
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.labels = labels;
        this.prior = prior;
        this.backgroundSymbol = backgroundSymbol;
        this.sigma = (float) sigma;
        this.empiricalCounts(data, labels);
    }

    /**
     * Returns the length of the flat weight vector: the sum, over features, of the number of
     * labelings of the clique each feature maps to. Computed once and cached.
     */
    @Override
    public int domainDimension() {
        if (domainDimension < 0) {
            int total = 0;
            for (int aMap : map) {
                total += labelIndices.get(aMap).size();
            }
            domainDimension = total;
        }
        return domainDimension;
    }

    /** Not supported by this float implementation; always throws. */
    @Override
    public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
        throw new UnsupportedOperationException("CRFLogConditionalObjectiveFloatFunction is not clique potential compatible yet");
    }

    /**
     * Splits a flat weight vector into one row per feature.
     *
     * @param weights flat weights, at least {@link #domainDimension()} long
     * @return a fresh array whose row f holds feature f's weights
     */
    public float[][] to2D(float[] weights) {
        float[][] newWeights = new float[map.length][];
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int size = labelIndices.get(map[i]).size();
            newWeights[i] = new float[size];
            System.arraycopy(weights, index, newWeights[i], 0, size);
            index += size;
        }
        return newWeights;
    }

    /**
     * Concatenates per-feature weight rows into a flat vector of length
     * {@link #domainDimension()}.
     */
    public float[] to1D(float[][] weights) {
        float[] newWeights = new float[domainDimension()];
        int index = 0;
        for (float[] weight : weights) {
            System.arraycopy(weight, 0, newWeights, index, weight.length);
            index += weight.length;
        }
        return newWeights;
    }

    /** Allocates a zero-filled array with the same per-feature layout as {@link #to2D}. */
    public float[][] empty2D() {
        float[][] d = new float[map.length][];
        for (int i = 0; i < map.length; ++i) {
            d[i] = new float[labelIndices.get(map[i]).size()];
        }
        return d;
    }

    /**
     * Fills {@link #Ehat} with the number of times each feature fires together with the gold
     * labeling of its clique.
     */
    private void empiricalCounts(int[][][][] data, int[][] labels) {
        Ehat = empty2D();
        int backgroundIndex = classIndex.indexOf(backgroundSymbol);
        for (int m = 0; m < data.length; ++m) {
            int[][][] doc = data[m];
            int[] docLabels = labels[m];
            // Sliding window of the most recent `window` gold labels, padded with the background class.
            int[] label = new int[window];
            Arrays.fill(label, backgroundIndex);
            for (int i = 0; i < doc.length; ++i) {
                System.arraycopy(label, 1, label, 0, window - 1);
                label[window - 1] = docLabels[i];
                for (int j = 0; j < doc[i].length; ++j) {
                    // Clique j covers the last j+1 labels of the window.
                    int[] cliqueLabel = Arrays.copyOfRange(label, window - 1 - j, window);
                    int labelIndex = labelIndices.get(j).indexOf(new CRFLabel(cliqueLabel));
                    for (int feature : doc[i][j]) {
                        Ehat[feature][labelIndex] += 1.0f;
                    }
                }
            }
        }
    }

    /**
     * Builds the factor table for one position. Each clique's labeling is scored by summing
     * its features' weights, and the tables of successive clique sizes are multiplied together.
     *
     * @return the table for the largest clique, or {@code null} if {@code labelIndices} is empty
     */
    public static FloatFactorTable getFloatFactorTable(float[][] weights, int[][] data, List<Index<CRFLabel>> labelIndices, int numClasses) {
        FloatFactorTable factorTable = null;
        for (int j = 0; j < labelIndices.size(); ++j) {
            Index<CRFLabel> labelIndex = labelIndices.get(j);
            FloatFactorTable ft = new FloatFactorTable(numClasses, j + 1);
            for (int k = 0; k < labelIndex.size(); ++k) {
                int[] label = labelIndex.get(k).getLabel();
                float weight = 0.0f;
                for (int feature : data[j]) {
                    weight += weights[feature][k];
                }
                ft.setValue(label, weight);
            }
            if (j > 0) {
                ft.multiplyInEnd(factorTable);
            }
            factorTable = ft;
        }
        return factorTable;
    }

    /**
     * Builds one factor table per position and calibrates them with a forward and a backward
     * message pass.
     *
     * <p>The forward pass sums out the front of each table and multiplies the result into the
     * next. The backward pass sums out the end of each table, divides by the corresponding
     * forward message, and multiplies the result into the previous table.
     *
     * @return one table per position; empty if {@code data} is empty
     */
    public static FloatFactorTable[] getCalibratedCliqueTree(float[][] weights, int[][][] data, List<Index<CRFLabel>> labelIndices, int numClasses) {
        int length = data.length;
        FloatFactorTable[] factorTables = new FloatFactorTable[length];
        // max(0, ...) so an empty document yields an empty array rather than NegativeArraySizeException.
        FloatFactorTable[] messages = new FloatFactorTable[Math.max(0, length - 1)];

        for (int i = 0; i < length; ++i) {
            factorTables[i] = getFloatFactorTable(weights, data[i], labelIndices, numClasses);
            if (VERBOSE) {
                log.info(i + ": " + factorTables[i]);
            }
            if (i > 0) {
                messages[i - 1] = factorTables[i - 1].sumOutFront();
                if (VERBOSE) {
                    log.info(messages[i - 1]);
                }
                factorTables[i].multiplyInFront(messages[i - 1]);
                if (VERBOSE) {
                    log.info(factorTables[i]);
                    if (i == length - 1) {
                        log.info(i + ": " + factorTables[i].toProbString());
                    }
                }
            }
        }

        for (int i = length - 2; i >= 0; --i) {
            FloatFactorTable summedOut = factorTables[i + 1].sumOutEnd();
            if (VERBOSE) {
                log.info((i + 1) + "-->" + i + ": " + summedOut);
            }
            // Remove the forward message already multiplied into table i+1.
            summedOut.divideBy(messages[i]);
            if (VERBOSE) {
                log.info((i + 1) + "-->" + i + ": " + summedOut);
            }
            factorTables[i].multiplyInEnd(summedOut);
            if (VERBOSE) {
                log.info(i + ": " + factorTables[i]);
                log.info(i + ": " + factorTables[i].toProbString());
            }
        }
        return factorTables;
    }

    /**
     * Computes the objective and its gradient at {@code x}.
     *
     * <p>The value is the negated sum, over documents and positions, of each gold label's
     * conditional log-probability given the preceding {@code window - 1} gold labels, plus the
     * prior. The gradient is expected feature counts minus empirical counts, plus the prior's
     * gradient.
     *
     * @param x flat weight vector
     * @throws IllegalStateException if the log-likelihood evaluates to NaN
     */
    @Override
    public void calculate(float[] x) {
        float[][] weights = to2D(x);
        float[][] E = empty2D();
        float prob = 0.0f;
        int backgroundIndex = classIndex.indexOf(backgroundSymbol);

        for (int m = 0; m < data.length; ++m) {
            int[][][] doc = data[m];
            if (doc.length == 0) {
                // An empty document contributes nothing to the likelihood or the expectations.
                continue;
            }
            int[] docLabels = labels[m];
            FloatFactorTable[] factorTables = getCalibratedCliqueTree(weights, doc, labelIndices, numClasses);
            // NOTE(review): totalMass() appears to be a log-space normalizer, since it is
            // subtracted from unnormalizedLogProbEnd() before exponentiating below.
            float z = factorTables[0].totalMass();

            // Log-likelihood of the gold sequence, one label at a time given its gold history.
            int[] given = new int[window - 1];
            Arrays.fill(given, backgroundIndex);
            for (int i = 0; i < doc.length; ++i) {
                float p = factorTables[i].conditionalLogProb(given, docLabels[i]);
                if (VERBOSE) {
                    log.info("P(" + docLabels[i] + "|" + Arrays.toString(given) + ")=" + p);
                }
                prob += p;
                // With window == 1 there is no history to shift.
                if (given.length > 0) {
                    System.arraycopy(given, 1, given, 0, given.length - 1);
                    given[given.length - 1] = docLabels[i];
                }
            }

            // Model expectation of each feature/labeling pair.
            for (int i = 0; i < doc.length; ++i) {
                for (int j = 0; j < doc[i].length; ++j) {
                    Index<CRFLabel> labelIndex = labelIndices.get(j);
                    int[] features = doc[i][j];
                    for (int k = 0; k < labelIndex.size(); ++k) {
                        int[] label = labelIndex.get(k).getLabel();
                        float p = (float) Math.exp(factorTables[i].unnormalizedLogProbEnd(label) - z);
                        for (int feature : features) {
                            E[feature][k] += p;
                        }
                    }
                }
            }
        }

        if (Float.isNaN(prob)) {
            // Previously this called System.exit(0), silently terminating the JVM with a success code.
            throw new IllegalStateException("CRF objective log-likelihood is NaN");
        }
        value = -prob;
        storeGradient(E, VERBOSE);
        applyPrior(x);
    }

    /**
     * Alternative objective in which every clique is normalized independently. Each clique's
     * labelings are scored by summed feature weights and turned into a local softmax. The value
     * is the negated sum of the gold labelings' local log-probabilities, plus the prior. The
     * gradient is local expected counts minus empirical counts, plus the prior's gradient.
     *
     * @param x flat weight vector
     */
    public void calculateWeird1(float[] x) {
        float[][] weights = to2D(x);
        float[][] E = empty2D();
        value = 0.0f;
        Arrays.fill(derivative, 0.0f);

        int numCliqueSizes = labelIndices.size();
        float[][] sums = new float[numCliqueSizes][];
        float[][] probs = new float[numCliqueSizes][];
        for (int cl = 0; cl < numCliqueSizes; ++cl) {
            int size = labelIndices.get(cl).size();
            sums[cl] = new float[size];
            probs[cl] = new float[size];
        }

        int backgroundIndex = classIndex.indexOf(backgroundSymbol);
        for (int d = 0; d < data.length; ++d) {
            int[] docLabels = labels[d];
            for (int e = 0; e < data[d].length; ++e) {
                int[][] cliqueFeatures = data[d][e];

                // Unnormalized log-score of every labeling of each clique.
                for (int cl = 0; cl < cliqueFeatures.length; ++cl) {
                    Arrays.fill(sums[cl], 0.0f);
                    int numLabelings = labelIndices.get(cl).size();
                    for (int c = 0; c < numLabelings; ++c) {
                        for (int feature : cliqueFeatures[cl]) {
                            sums[cl][c] += weights[feature][c];
                        }
                    }
                }

                // Local softmax per clique and the gold labeling's log-probability.
                for (int cl = 0; cl < cliqueFeatures.length; ++cl) {
                    // Gold labels for positions e-cl..e, padded with the background class before the document start.
                    int[] label = new int[cl + 1];
                    Arrays.fill(label, backgroundIndex);
                    int slot = label.length - 1;
                    for (int pos = e; pos >= 0 && slot >= 0; --pos) {
                        label[slot--] = docLabels[pos];
                    }
                    int labelIndex = labelIndices.get(cl).indexOf(new CRFLabel(label));
                    float total = ArrayMath.logSum(sums[cl]);
                    int numLabelings = labelIndices.get(cl).size();
                    for (int c = 0; c < numLabelings; ++c) {
                        probs[cl][c] = (float) Math.exp(sums[cl][c] - total);
                    }
                    value -= sums[cl][labelIndex] - total;
                }

                // Local expected counts.
                for (int j = 0; j < cliqueFeatures.length; ++j) {
                    int numLabelings = labelIndices.get(j).size();
                    for (int k = 0; k < numLabelings; ++k) {
                        float p = probs[j][k];
                        for (int feature : cliqueFeatures[j]) {
                            E[feature][k] += p;
                        }
                    }
                }
            }
        }

        storeGradient(E, false);
        applyPrior(x);
    }

    /**
     * Writes {@code expected - Ehat} into {@link #derivative}, flattened row by row. When
     * {@code logEach} is set, each entry is also logged.
     */
    private void storeGradient(float[][] expected, boolean logEach) {
        int index = 0;
        for (int i = 0; i < expected.length; ++i) {
            for (int j = 0; j < expected[i].length; ++j) {
                derivative[index] = expected[i][j] - Ehat[i][j];
                if (logEach) {
                    log.info("deriv(" + i + "," + j + ") = " + expected[i][j] + " - " + Ehat[i][j] + " = " + derivative[index]);
                }
                ++index;
            }
        }
    }

    /** Adds the configured prior's penalty to {@link #value} and its gradient to {@link #derivative}. */
    private void applyPrior(float[] x) {
        switch (prior) {
            case QUADRATIC_PRIOR: {
                // w^2 / (2 sigma^2), gradient w / sigma^2.
                float sigmaSq = sigma * sigma;
                for (int i = 0; i < x.length; ++i) {
                    float w = x[i];
                    value += w * w / 2.0 / sigmaSq;
                    derivative[i] += w / sigmaSq;
                }
                break;
            }
            case HUBER_PRIOR: {
                // Quadratic inside |w| < epsilon, linear outside.
                float sigmaSq = sigma * sigma;
                for (int i = 0; i < x.length; ++i) {
                    float w = x[i];
                    float wabs = Math.abs(w);
                    if (wabs < epsilon) {
                        value += w * w / 2.0 / epsilon / sigmaSq;
                        derivative[i] += w / epsilon / sigmaSq;
                    } else {
                        value += (wabs - epsilon / 2.0f) / sigmaSq;
                        derivative[i] += (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                // w^4 / (2 sigma^4), gradient 2 w^3 / sigma^4 (previously w / sigma^4, which was not
                // the derivative of the penalty being added).
                float sigmaQu = sigma * sigma * sigma * sigma;
                for (int i = 0; i < x.length; ++i) {
                    float w = x[i];
                    value += w * w * w * w / 2.0 / sigmaQu;
                    derivative[i] += 2.0f * w * w * w / sigmaQu;
                }
                break;
            }
            default:
                // NO_PRIOR (or an unrecognized setting): no penalty.
                break;
        }
    }
}

