/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

/**
 * Activation and cost functions, plus their derivatives, used by the multilayer perceptron,
 * along with lookups that resolve them by name.
 *
 * <p>Name lookups are case-insensitive ({@link String#equalsIgnoreCase}). An unrecognised name
 * raises {@link IllegalArgumentException}; a {@code null} name raises
 * {@link NullPointerException}.
 */
public class NeuralNetworkFunctions {

    /** Derivative of the identity function f(x) = x: always returns 1. */
    public static DoubleFunction derivativeIdentityFunction = new DoubleFunction(){

        @Override
        public double apply(double x) {
            return 1.0;
        }
    };

    /**
     * Derivative of the squared error (output - target)^2 with respect to {@code output},
     * i.e. 2 * (output - target).
     * NOTE(review): intended to pair with {@code Functions.MINUS_SQUARED}; confirm that
     * function's argument order is (target, output).
     */
    public static DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction(){

        @Override
        public double apply(double target, double output) {
            return 2.0 * (output - target);
        }
    };

    /**
     * Binary cross-entropy cost: -t * log(o) - (1 - t) * log(1 - o).
     *
     * <p>A term whose coefficient is exactly zero is skipped, following the usual convention
     * 0 * log(0) = 0. Evaluating it literally would produce NaN for a perfect prediction
     * (e.g. target 0, output 0), because {@code 0.0 * -Infinity} is NaN in IEEE arithmetic.
     * For every other input the result is the same as the literal formula.
     */
    public static DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction(){

        @Override
        public double apply(double target, double output) {
            double loss = 0.0;
            if (target != 0.0) {
                loss -= target * Math.log(output);
            }
            if (target != 1.0) {
                loss -= (1.0 - target) * Math.log(1.0 - output);
            }
            return loss;
        }
    };

    /**
     * Derivative of the binary cross-entropy cost with respect to {@code output}:
     * -t / o + (1 - t) / (1 - o).
     *
     * <p>An output of exactly 1.0 or 0.0 is nudged to 0.999 or 0.001 so that neither
     * denominator is zero. The target is nudged the same way.
     * NOTE(review): the reason for also adjusting the target is not evident from this code.
     */
    public static DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction(){

        @Override
        public double apply(double target, double output) {
            double adjustedTarget = target;
            double adjustedActual = output;
            if (adjustedActual == 1.0) {
                adjustedActual = 0.999;
            } else if (adjustedActual == 0.0) {
                adjustedActual = 0.001;
            }
            if (adjustedTarget == 1.0) {
                adjustedTarget = 0.999;
            } else if (adjustedTarget == 0.0) {
                adjustedTarget = 0.001;
            }
            return -adjustedTarget / adjustedActual + (1.0 - adjustedTarget) / (1.0 - adjustedActual);
        }
    };

    /**
     * Resolves an activation function by name.
     *
     * @param function "Identity" or "Sigmoid" (case-insensitive)
     * @return the matching activation function
     * @throws IllegalArgumentException if the name is not supported
     */
    public static DoubleFunction getDoubleFunction(String function) {
        if (function.equalsIgnoreCase("Identity")) {
            return Functions.IDENTITY;
        }
        if (function.equalsIgnoreCase("Sigmoid")) {
            return Functions.SIGMOID;
        }
        throw new IllegalArgumentException("Function not supported.");
    }

    /**
     * Resolves the derivative of an activation function by name.
     *
     * @param function "Identity" or "Sigmoid" (case-insensitive)
     * @return the derivative of the matching activation function
     * @throws IllegalArgumentException if the name is not supported
     */
    public static DoubleFunction getDerivativeDoubleFunction(String function) {
        if (function.equalsIgnoreCase("Identity")) {
            return derivativeIdentityFunction;
        }
        if (function.equalsIgnoreCase("Sigmoid")) {
            return Functions.SIGMOIDGRADIENT;
        }
        throw new IllegalArgumentException("Function not supported.");
    }

    /**
     * Resolves a cost function, applied as (target, output), by name.
     *
     * @param function "Minus_Squared" or "Cross_Entropy" (case-insensitive)
     * @return the matching cost function itself, not its derivative
     * @throws IllegalArgumentException if the name is not supported
     */
    public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
        if (function.equalsIgnoreCase("Minus_Squared")) {
            return Functions.MINUS_SQUARED;
        }
        if (function.equalsIgnoreCase("Cross_Entropy")) {
            // Previously returned derivativeCrossEntropy, which mixed up the cost with its
            // gradient (the Minus_Squared branch returns the function, not its derivative).
            return crossEntropy;
        }
        throw new IllegalArgumentException("Function not supported.");
    }

    /**
     * Resolves the derivative of a cost function, with respect to the output, by name.
     *
     * @param function "Minus_Squared" or "Cross_Entropy" (case-insensitive)
     * @return the derivative of the matching cost function
     * @throws IllegalArgumentException if the name is not supported
     */
    public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
        if (function.equalsIgnoreCase("Minus_Squared")) {
            return derivativeMinusSquared;
        }
        if (function.equalsIgnoreCase("Cross_Entropy")) {
            return derivativeCrossEntropy;
        }
        throw new IllegalArgumentException("Function not supported.");
    }
}

