/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.LogisticUtils;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.optimization.HasRegularizerParamRange;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class ShiftParamsLogisticObjectiveFunction
extends AbstractCachingDiffFunction
implements HasRegularizerParamRange {
    private final int[][] data;
    private final double[][] dataValues;
    private final int numClasses;
    private final int numFeatures;
    private final int[][] labels;
    private final int numL2Parameters;
    private final LogPrior prior;

    public ShiftParamsLogisticObjectiveFunction(int[][] data, double[][] dataValues, int[][] labels, int numClasses, int numFeatures, int numL2Parameters, LogPrior prior) {
        this.data = data;
        this.dataValues = dataValues;
        this.labels = labels;
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.numL2Parameters = numL2Parameters;
        this.prior = prior;
    }

    @Override
    public int domainDimension() {
        return (this.numClasses - 1) * this.numFeatures;
    }

    @Override
    protected void calculate(double[] thetasArray) {
        this.clearResults();
        double[][] thetas = new double[this.numClasses - 1][this.numFeatures];
        LogisticUtils.unflatten(thetasArray, thetas);
        for (int i = 0; i < this.data.length; ++i) {
            int[] featureIndices = this.data[i];
            double[] featureValues = this.dataValues[i];
            double[] sums = LogisticUtils.calculateSums(thetas, featureIndices, featureValues);
            for (int c = 0; c < this.numClasses; ++c) {
                double sum = sums[c];
                this.value -= sum * (double)this.labels[i][c];
                if (c == 0) continue;
                int offset = (c - 1) * this.numFeatures;
                double error = Math.exp(sum) - (double)this.labels[i][c];
                for (int f = 0; f < featureIndices.length; ++f) {
                    int index = featureIndices[f];
                    double x = featureValues[f];
                    int n = offset + index;
                    this.derivative[n] = this.derivative[n] - error * x;
                }
            }
        }
        if (this.prior.getType().equals((Object)LogPrior.LogPriorType.NULL)) {
            return;
        }
        double sigma = this.prior.getSigma();
        for (int c = 0; c < this.numClasses; ++c) {
            if (c == 0) continue;
            int offset = (c - 1) * this.numFeatures;
            for (int j = 0; j < this.numL2Parameters; ++j) {
                double theta = thetasArray[offset + j];
                this.value += theta * theta / (sigma * 2.0);
                int n = offset + j;
                this.derivative[n] = this.derivative[n] + theta / sigma;
            }
        }
    }

    private void clearResults() {
        this.value = 0.0;
        Arrays.fill(this.derivative, 0.0);
    }

    @Override
    public Set<Integer> getRegularizerParamRange(double[] x) {
        HashSet<Integer> result = new HashSet<Integer>();
        for (int i = this.numL2Parameters; i < x.length; ++i) {
            result.add(i);
        }
        return result;
    }
}

