/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.inference;

import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.loglinear.model.NDArrayDoubles;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.function.BiFunction;
import java.util.function.Supplier;

public class TableFactor
extends NDArrayDoubles {
    public int[] neighborIndices;
    public static final boolean USE_EXP_APPROX = false;

    public TableFactor(ConcatVector weights, GraphicalModel.Factor factor) {
        super(factor.featuresTable.getDimensions());
        this.neighborIndices = factor.neigborIndices;
        Iterator<int[]> fastPassByReferenceIterator = factor.featuresTable.fastPassByReferenceIterator();
        int[] assignment = fastPassByReferenceIterator.next();
        while (true) {
            this.setAssignmentLogValue(assignment, ((ConcatVector)((Supplier)factor.featuresTable.getAssignmentValue(assignment)).get()).dotProduct(weights));
            if (!fastPassByReferenceIterator.hasNext()) break;
            fastPassByReferenceIterator.next();
        }
    }

    public static double exp(double val) {
        long tmp = (long)(1512775.0 * val + 1.072632447E9);
        return Double.longBitsToDouble(tmp << 32);
    }

    public TableFactor(ConcatVector weights, GraphicalModel.Factor factor, int[] observations) {
        assert (observations.length == factor.neigborIndices.length);
        int size = 0;
        for (int observation : observations) {
            if (observation != -1) continue;
            ++size;
        }
        this.neighborIndices = new int[size];
        this.dimensions = new int[size];
        int[] forwardPointers = new int[size];
        int[] factorAssignment = new int[factor.neigborIndices.length];
        int cursor = 0;
        for (int i = 0; i < factor.neigborIndices.length; ++i) {
            if (observations[i] == -1) {
                this.neighborIndices[cursor] = factor.neigborIndices[i];
                this.dimensions[cursor] = factor.featuresTable.getDimensions()[i];
                forwardPointers[cursor] = i;
                ++cursor;
                continue;
            }
            factorAssignment[i] = observations[i];
        }
        assert (cursor == size);
        this.values = new double[this.combinatorialNeighborStatesCount()];
        for (int[] assn : this) {
            for (int i = 0; i < assn.length; ++i) {
                factorAssignment[forwardPointers[i]] = assn[i];
            }
            this.setAssignmentLogValue(assn, ((ConcatVector)((Supplier)factor.featuresTable.getAssignmentValue(factorAssignment)).get()).dotProduct(weights));
        }
    }

    public TableFactor observe(int variable, int value) {
        return this.marginalize(variable, 0.0, (marginalizedVariableValue, assignment) -> {
            if (marginalizedVariableValue == value) {
                return (old, n) -> n;
            }
            return (old, n) -> old;
        });
    }

    public double[][] getSummedMarginals() {
        double[][] results = new double[this.neighborIndices.length][];
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            results[i] = new double[this.getDimensions()[i]];
        }
        double[][] maxValues = new double[this.neighborIndices.length][];
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            maxValues[i] = new double[this.getDimensions()[i]];
            for (int j = 0; j < maxValues[i].length; ++j) {
                maxValues[i][j] = Double.NEGATIVE_INFINITY;
            }
        }
        Iterator<int[]> fastPassByReferenceIterator = this.fastPassByReferenceIterator();
        int[] assignment = fastPassByReferenceIterator.next();
        while (true) {
            double v = this.getAssignmentLogValue(assignment);
            for (int i = 0; i < this.neighborIndices.length; ++i) {
                if (!(maxValues[i][assignment[i]] < v)) continue;
                maxValues[i][assignment[i]] = v;
            }
            if (!fastPassByReferenceIterator.hasNext()) break;
            fastPassByReferenceIterator.next();
        }
        Iterator<int[]> secondFastPassByReferenceIterator = this.fastPassByReferenceIterator();
        assignment = secondFastPassByReferenceIterator.next();
        while (true) {
            double v = this.getAssignmentLogValue(assignment);
            for (int i = 0; i < this.neighborIndices.length; ++i) {
                double[] dArray = results[i];
                int n = assignment[i];
                dArray[n] = dArray[n] + Math.exp(v - maxValues[i][assignment[i]]);
            }
            if (!secondFastPassByReferenceIterator.hasNext()) break;
            secondFastPassByReferenceIterator.next();
        }
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            int j;
            double sum = 0.0;
            for (j = 0; j < results[i].length; ++j) {
                results[i][j] = Math.exp(maxValues[i][j]) * results[i][j];
                sum += results[i][j];
            }
            if (Double.isInfinite(sum)) {
                for (j = 0; j < results[i].length; ++j) {
                    results[i][j] = 1.0 / (double)results[i].length;
                }
                continue;
            }
            j = 0;
            while (j < results[i].length) {
                double[] dArray = results[i];
                int n = j++;
                dArray[n] = dArray[n] / sum;
            }
        }
        return results;
    }

    public double[][] getMaxedMarginals() {
        double[][] maxValues = new double[this.neighborIndices.length][];
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            maxValues[i] = new double[this.getDimensions()[i]];
            for (int j = 0; j < maxValues[i].length; ++j) {
                maxValues[i][j] = Double.NEGATIVE_INFINITY;
            }
        }
        Iterator<int[]> fastPassByReferenceIterator = this.fastPassByReferenceIterator();
        int[] assignment = fastPassByReferenceIterator.next();
        while (true) {
            double v = this.getAssignmentLogValue(assignment);
            for (int i = 0; i < this.neighborIndices.length; ++i) {
                if (!(maxValues[i][assignment[i]] < v)) continue;
                maxValues[i][assignment[i]] = v;
            }
            if (!fastPassByReferenceIterator.hasNext()) break;
            fastPassByReferenceIterator.next();
        }
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            TableFactor.normalizeLogArr(maxValues[i]);
        }
        return maxValues;
    }

    public TableFactor maxOut(int variable) {
        return this.marginalize(variable, Double.NEGATIVE_INFINITY, (marginalizedVariableValue, assignment) -> Math::max);
    }

    public TableFactor sumOut(int variable) {
        if (this.getDimensions().length == 2) {
            int index;
            int j;
            int k;
            int i;
            if (this.neighborIndices[0] == variable) {
                int index2;
                int j2;
                int k2;
                int i2;
                int j3;
                TableFactor marginalized = new TableFactor(new int[]{this.neighborIndices[1]}, new int[]{this.getDimensions()[1]});
                for (int i3 = 0; i3 < marginalized.values.length; ++i3) {
                    marginalized.values[i3] = 0.0;
                }
                double[] max = new double[this.getDimensions()[1]];
                for (j3 = 0; j3 < this.getDimensions()[1]; ++j3) {
                    max[j3] = Double.NEGATIVE_INFINITY;
                }
                for (i2 = 0; i2 < this.getDimensions()[0]; ++i2) {
                    k2 = i2 * this.getDimensions()[1];
                    for (j2 = 0; j2 < this.getDimensions()[1]; ++j2) {
                        index2 = k2 + j2;
                        if (!(this.values[index2] > max[j2])) continue;
                        max[j2] = this.values[index2];
                    }
                }
                for (i2 = 0; i2 < this.getDimensions()[0]; ++i2) {
                    k2 = i2 * this.getDimensions()[1];
                    for (j2 = 0; j2 < this.getDimensions()[1]; ++j2) {
                        index2 = k2 + j2;
                        if (!Double.isFinite(max[j2])) continue;
                        int n = j2;
                        marginalized.values[n] = marginalized.values[n] + Math.exp(this.values[index2] - max[j2]);
                    }
                }
                for (j3 = 0; j3 < this.getDimensions()[1]; ++j3) {
                    marginalized.values[j3] = Double.isFinite(max[j3]) ? max[j3] + Math.log(marginalized.values[j3]) : max[j3];
                }
                return marginalized;
            }
            assert (this.neighborIndices[1] == variable);
            TableFactor marginalized = new TableFactor(new int[]{this.neighborIndices[0]}, new int[]{this.getDimensions()[0]});
            for (int i4 = 0; i4 < marginalized.values.length; ++i4) {
                marginalized.values[i4] = 0.0;
            }
            double[] max = new double[this.getDimensions()[0]];
            for (i = 0; i < this.getDimensions()[0]; ++i) {
                max[i] = Double.NEGATIVE_INFINITY;
            }
            for (i = 0; i < this.getDimensions()[0]; ++i) {
                k = i * this.getDimensions()[1];
                for (j = 0; j < this.getDimensions()[1]; ++j) {
                    index = k + j;
                    if (!(this.values[index] > max[i])) continue;
                    max[i] = this.values[index];
                }
            }
            for (i = 0; i < this.getDimensions()[0]; ++i) {
                k = i * this.getDimensions()[1];
                for (j = 0; j < this.getDimensions()[1]; ++j) {
                    index = k + j;
                    if (!Double.isFinite(max[i])) continue;
                    int n = i;
                    marginalized.values[n] = marginalized.values[n] + Math.exp(this.values[index] - max[i]);
                }
            }
            for (i = 0; i < this.getDimensions()[0]; ++i) {
                marginalized.values[i] = Double.isFinite(max[i]) ? max[i] + Math.log(marginalized.values[i]) : max[i];
            }
            return marginalized;
        }
        TableFactor maxValues = this.maxOut(variable);
        TableFactor marginalized = this.marginalize(variable, 0.0, (marginalizedVariableValue, assignment) -> (a, b) -> a + Math.exp(b - maxValues.getAssignmentLogValue((int[])assignment)));
        for (int[] assignment2 : marginalized) {
            marginalized.setAssignmentLogValue(assignment2, maxValues.getAssignmentLogValue(assignment2) + Math.log(marginalized.getAssignmentLogValue(assignment2)));
        }
        return marginalized;
    }

    public TableFactor multiply(TableFactor other) {
        ArrayList<Integer> domain = new ArrayList<Integer>();
        ArrayList<Integer> otherDomain = new ArrayList<Integer>();
        ArrayList<Integer> resultDomain = new ArrayList<Integer>();
        for (int n : this.neighborIndices) {
            domain.add(n);
            resultDomain.add(n);
        }
        for (int n : other.neighborIndices) {
            otherDomain.add(n);
            if (resultDomain.contains(n)) continue;
            resultDomain.add(n);
        }
        int[] resultNeighborIndices = new int[resultDomain.size()];
        int[] resultDimensions = new int[resultNeighborIndices.length];
        for (int i = 0; i < resultDomain.size(); ++i) {
            int var;
            resultNeighborIndices[i] = var = ((Integer)resultDomain.get(i)).intValue();
            assert (this.getVariableSize(var) == 0 && other.getVariableSize(var) > 0 || this.getVariableSize(var) > 0 && other.getVariableSize(var) == 0 || this.getVariableSize(var) == other.getVariableSize(var));
            resultDimensions[i] = Math.max(this.getVariableSize((Integer)resultDomain.get(i)), other.getVariableSize((Integer)resultDomain.get(i)));
        }
        TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions);
        if (otherDomain.size() == 1 && resultDomain.size() == domain.size() && domain.size() == 2) {
            int msgVar = (Integer)otherDomain.get(0);
            int msgIndex = resultDomain.indexOf(msgVar);
            if (msgIndex == 0) {
                for (int i = 0; i < resultDimensions[0]; ++i) {
                    double d = other.values[i];
                    int k = i * resultDimensions[1];
                    for (int j = 0; j < resultDimensions[1]; ++j) {
                        int index = k + j;
                        result.values[index] = this.values[index] + d;
                    }
                }
            } else if (msgIndex == 1) {
                for (int i = 0; i < resultDimensions[0]; ++i) {
                    int k = i * resultDimensions[1];
                    for (int j = 0; j < resultDimensions[1]; ++j) {
                        int index = k + j;
                        result.values[index] = this.values[index] + other.values[j];
                    }
                }
            }
        } else {
            if (domain.size() == 1 && resultDomain.size() == otherDomain.size() && resultDomain.size() == 2) {
                return other.multiply(this);
            }
            int[] mapping = new int[result.neighborIndices.length];
            int[] otherMapping = new int[result.neighborIndices.length];
            for (int i = 0; i < result.neighborIndices.length; ++i) {
                mapping[i] = domain.indexOf(result.neighborIndices[i]);
                otherMapping[i] = otherDomain.indexOf(result.neighborIndices[i]);
            }
            int[] assignment = new int[this.neighborIndices.length];
            int[] otherAssignment = new int[other.neighborIndices.length];
            Iterator<int[]> fastPassByReferenceIterator = result.fastPassByReferenceIterator();
            int[] resultAssignment = fastPassByReferenceIterator.next();
            while (true) {
                for (int i = 0; i < resultAssignment.length; ++i) {
                    if (mapping[i] != -1) {
                        assignment[mapping[i]] = resultAssignment[i];
                    }
                    if (otherMapping[i] == -1) continue;
                    otherAssignment[otherMapping[i]] = resultAssignment[i];
                }
                result.setAssignmentLogValue(resultAssignment, this.getAssignmentLogValue(assignment) + other.getAssignmentLogValue(otherAssignment));
                if (!fastPassByReferenceIterator.hasNext()) break;
                fastPassByReferenceIterator.next();
            }
        }
        return result;
    }

    public double valueSum() {
        double max = 0.0;
        for (int[] assignment : this) {
            double v = this.getAssignmentLogValue(assignment);
            if (!(v > max)) continue;
            max = v;
        }
        double sumExp = 0.0;
        for (int[] assignment : this) {
            sumExp += Math.exp(this.getAssignmentLogValue(assignment) - max);
        }
        return sumExp * Math.exp(max);
    }

    @Override
    public double getAssignmentValue(int[] assignment) {
        double d = super.getAssignmentValue(assignment);
        return Math.exp(d);
    }

    @Override
    public void setAssignmentValue(int[] assignment, double value) {
        super.setAssignmentValue(assignment, Math.log(value));
    }

    private double getAssignmentLogValue(int[] assignment) {
        return super.getAssignmentValue(assignment);
    }

    private void setAssignmentLogValue(int[] assignment, double value) {
        super.setAssignmentValue(assignment, value);
    }

    private TableFactor marginalize(int variable, double startingValue, BiFunction<Integer, int[], BiFunction<Double, Double, Double>> curriedFoldr) {
        assert (this.getDimensions().length > 1);
        ArrayList<Integer> resultDomain = new ArrayList<Integer>();
        for (int n : this.neighborIndices) {
            if (n == variable) continue;
            resultDomain.add(n);
        }
        int[] resultNeighborIndices = new int[resultDomain.size()];
        int[] resultDimensions = new int[resultNeighborIndices.length];
        for (int i = 0; i < resultDomain.size(); ++i) {
            int var;
            resultNeighborIndices[i] = var = ((Integer)resultDomain.get(i)).intValue();
            resultDimensions[i] = this.getVariableSize(var);
        }
        TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions);
        int[] mapping = new int[this.neighborIndices.length];
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            mapping[i] = resultDomain.indexOf(this.neighborIndices[i]);
        }
        for (int[] assignment : result) {
            result.setAssignmentLogValue(assignment, startingValue);
        }
        int[] resultAssignment = new int[result.neighborIndices.length];
        int marginalizedVariableValue = 0;
        Iterator<int[]> fastPassByReferenceIterator = this.fastPassByReferenceIterator();
        int[] assignment = fastPassByReferenceIterator.next();
        while (true) {
            for (int i = 0; i < assignment.length; ++i) {
                if (mapping[i] != -1) {
                    resultAssignment[mapping[i]] = assignment[i];
                    continue;
                }
                marginalizedVariableValue = assignment[i];
            }
            result.setAssignmentLogValue(resultAssignment, curriedFoldr.apply(marginalizedVariableValue, resultAssignment).apply(result.getAssignmentLogValue(resultAssignment), this.getAssignmentLogValue(assignment)));
            if (!fastPassByReferenceIterator.hasNext()) break;
            fastPassByReferenceIterator.next();
        }
        return result;
    }

    private int getVariableSize(int variable) {
        for (int i = 0; i < this.neighborIndices.length; ++i) {
            if (this.neighborIndices[i] != variable) continue;
            return this.getDimensions()[i];
        }
        return 0;
    }

    private static void normalizeLogArr(double[] arr) {
        int i;
        double max = Double.NEGATIVE_INFINITY;
        for (double d : arr) {
            if (!(d > max)) continue;
            max = d;
        }
        double expSum = 0.0;
        for (double d : arr) {
            expSum += Math.exp(d - max);
        }
        double logSumExp = max + Math.log(expSum);
        if (Double.isInfinite(logSumExp)) {
            for (i = 0; i < arr.length; ++i) {
                arr[i] = 1.0 / (double)arr.length;
            }
        } else {
            for (i = 0; i < arr.length; ++i) {
                arr[i] = Math.exp(arr[i] - logSumExp);
            }
        }
    }

    TableFactor(int[] neighborIndices, int[] dimensions) {
        super(dimensions);
        this.neighborIndices = neighborIndices;
        for (int i = 0; i < this.values.length; ++i) {
            this.values[i] = Double.NEGATIVE_INFINITY;
        }
    }

    private boolean assertsEnabled() {
        boolean assertsEnabled = false;
        if (!$assertionsDisabled) {
            assertsEnabled = true;
            if (!true) {
                throw new AssertionError();
            }
        }
        return assertsEnabled;
    }
}

