/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.learning;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.InRange;
import edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer;
import edu.stanford.nlp.loglinear.learning.AbstractDifferentiableFunction;
import edu.stanford.nlp.loglinear.learning.BacktrackingAdaGradOptimizer;
import edu.stanford.nlp.loglinear.learning.LogLikelihoodDifferentiableFunction;
import edu.stanford.nlp.loglinear.learning.LogLikelihoodFunctionTest;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.Random;
import org.junit.Assert;
import org.junit.contrib.theories.DataPoint;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

/**
 * Property-based tests for {@link AbstractBatchOptimizer} implementations.
 *
 * <p>Each theory runs an optimizer on a generated dataset and then probes the result with many
 * small random perturbations, asserting that none of them yields a noticeably better objective.
 */
@RunWith(value=Theories.class)
public class OptimizerTests {
    /** Number of random perturbations probed around the optimizer's result. */
    private static final int NUM_PERTURBATIONS = 1000;
    /** Each perturbation coordinate is drawn uniformly from [-SCALE/2, SCALE/2). */
    private static final double PERTURBATION_SCALE = 0.001;
    /** Relative slack (scaled by max(1, |objective|)) before a perturbed point counts as better. */
    private static final double RELATIVE_TOLERANCE = 0.001;

    @DataPoint
    public static AbstractBatchOptimizer backtrackingAdaGrad = new BacktrackingAdaGradOptimizer();

    /**
     * Optimizes the regularized log-likelihood and checks the result is a local optimum by
     * random probing: no perturbed weight vector may beat it by more than the tolerance.
     *
     * @param optimizer the optimizer under test (supplied by {@code @DataPoint} fields)
     * @param dataset generated training instances
     * @param initialWeights generated starting point for the optimizer
     * @param l2regularization L2 penalty strength in [0, 5]
     */
    @Theory
    public void testOptimizeLogLikelihood(AbstractBatchOptimizer optimizer, @ForAll(sampleSize=5) @From(value={LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class}) GraphicalModel[] dataset, @ForAll(sampleSize=2) @From(value={LogLikelihoodFunctionTest.WeightsGenerator.class}) ConcatVector initialWeights, @ForAll(sampleSize=2) @InRange(minDouble=0.0, maxDouble=5.0) double l2regularization) throws Exception {
        LogLikelihoodDifferentiableFunction ll = new LogLikelihoodDifferentiableFunction();
        // NOTE(review): 1.0E-9 looks like a convergence threshold and `true` a flag defined by
        // AbstractBatchOptimizer.optimize — confirm against that signature.
        ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0E-9, true);
        System.err.println("Finished optimizing");
        double logLikelihood = this.getValueSum(dataset, finalWeights, ll, l2regularization);

        // Loop-invariant slack: a perturbed point must beat the optimum by more than this to fail.
        double tolerance = RELATIVE_TOLERANCE * Math.max(1.0, Math.abs(logLikelihood));
        // Fixed seed keeps the probe directions reproducible across runs.
        Random r = new Random(42L);
        for (int i = 0; i < NUM_PERTURBATIONS; ++i) {
            ConcatVector randomPerturbation = finalWeights.deepClone();
            randomPerturbation.addVectorInPlace(sampleRandomDirection(finalWeights, r), 1.0);
            double randomPerturbedLogLikelihood = this.getValueSum(dataset, randomPerturbation, ll, l2regularization);
            // Written as !(a >= b) rather than (a < b) so a NaN objective still fails the test.
            if (!(logLikelihood >= randomPerturbedLogLikelihood - tolerance)) {
                Assert.fail("Thought optimal point was: " + logLikelihood
                        + "; discovered better point: " + randomPerturbedLogLikelihood);
            }
        }
    }

    /**
     * Builds a vector with the same component layout as {@code template}, filled with small
     * uniform random values in [-PERTURBATION_SCALE/2, PERTURBATION_SCALE/2).
     *
     * @param template vector whose component count and per-component lengths are mirrored
     * @param r source of randomness; consumed in component order, then index order
     * @return a new dense perturbation vector
     */
    private static ConcatVector sampleRandomDirection(ConcatVector template, Random r) {
        int size = template.getNumberOfComponents();
        ConcatVector direction = new ConcatVector(size);
        for (int j = 0; j < size; ++j) {
            // For sparse components, make the dense array long enough to cover the stored
            // sparse index (presumably the single populated coordinate — confirm in ConcatVector).
            int length = template.isComponentSparse(j) ? template.getSparseIndex(j) + 1 : template.getDenseComponent(j).length;
            double[] dense = new double[length];
            for (int k = 0; k < length; ++k) {
                dense[k] = (r.nextDouble() - 0.5) * PERTURBATION_SCALE;
            }
            direction.setDenseComponent(j, dense);
        }
        return direction;
    }

    /**
     * Computes the mean per-instance value of {@code fn} at {@code weights}, minus an L2 penalty
     * of {@code l2regularization * ||weights||^2}.
     *
     * <p>NOTE(review): this is assumed to mirror the objective the optimizer maximizes — confirm
     * the penalty form against AbstractBatchOptimizer. An empty {@code dataset} yields NaN
     * (0.0 / 0).
     *
     * @param dataset instances to evaluate
     * @param weights point at which to evaluate the objective
     * @param fn per-instance differentiable function
     * @param l2regularization L2 penalty strength
     * @return the regularized average value
     */
    private <T> double getValueSum(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction<T> fn, double l2regularization) {
        double value = 0.0;
        for (T t : dataset) {
            // A fresh empty vector is passed as the third argument (presumably a gradient
            // accumulator that is discarded here — only the returned value is used).
            value += fn.getSummaryForInstance(t, weights, new ConcatVector(0));
        }
        return value / (double)dataset.length - weights.dotProduct(weights) * l2regularization;
    }
}

