/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.script;

import java.math.BigInteger;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.knn.common.KNNValidationUtil;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;

/**
 * Static helpers that compute distances and similarities between a query vector
 * and a stored (document) vector for k-NN script scoring.
 *
 * <p>Each metric is offered in up to three flavours: {@code float[]} arrays,
 * {@code byte[]} arrays, and a {@code List<Number>} query paired with
 * {@link KNNVectorScriptDocValues}, where the document's {@link VectorDataType}
 * selects the float or byte implementation.
 */
public class KNNScoringUtil {
    private static final Logger logger = LogManager.getLogger(KNNScoringUtil.class);

    /** Debug message emitted whenever a cosine score cannot be computed and 0 is returned instead. */
    private static final String INVALID_COSINE_MESSAGE = "Invalid vectors for cosine. Returning minimum score to put this result to end";

    /**
     * Ensures both float vectors are non-null and have the same number of dimensions.
     *
     * @param queryVector query vector
     * @param inputVector stored vector the query is compared against
     * @throws NullPointerException     if either vector is {@code null}
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireEqualDimension(float[] queryVector, float[] inputVector) {
        Objects.requireNonNull(queryVector, "queryVector");
        Objects.requireNonNull(inputVector, "inputVector");
        requireEqualLength(queryVector.length, inputVector.length);
    }

    /**
     * Ensures both byte vectors are non-null and have the same number of dimensions.
     *
     * @param queryVector query vector
     * @param inputVector stored vector the query is compared against
     * @throws NullPointerException     if either vector is {@code null}
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireEqualDimension(byte[] queryVector, byte[] inputVector) {
        Objects.requireNonNull(queryVector, "queryVector");
        Objects.requireNonNull(inputVector, "inputVector");
        requireEqualLength(queryVector.length, inputVector.length);
    }

    /**
     * Shared length comparison behind both {@code requireEqualDimension} overloads.
     *
     * @param queryLength number of dimensions in the query vector
     * @param inputLength number of dimensions in the stored vector (reported as "Expected")
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireEqualLength(int queryLength, int inputLength) {
        if (queryLength != inputLength) {
            // Locale.ROOT keeps the digits ASCII regardless of the JVM default locale,
            // matching the formatting used by the data-type checks below.
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "query vector dimension mismatch. Expected: %d, Given: %d", inputLength, queryLength)
            );
        }
    }

    /**
     * Rejects binary vector fields for spaces that only support float or byte data.
     *
     * @param spaceName      name of the scoring function, used in the error message
     * @param vectorDataType data type of the document field
     * @throws IllegalArgumentException if {@code vectorDataType} is {@link VectorDataType#BINARY}
     */
    private static void requireNonBinaryType(String spaceName, VectorDataType vectorDataType) {
        if (VectorDataType.BINARY == vectorDataType) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "Incompatible field_type for %s space. The data type should be either float or byte but got %s",
                    spaceName,
                    vectorDataType.getValue()
                )
            );
        }
    }

    /**
     * Rejects non-binary vector fields for spaces that only support binary data.
     *
     * @param spaceName      name of the scoring function, used in the error message
     * @param vectorDataType data type of the document field
     * @throws IllegalArgumentException if {@code vectorDataType} is not {@link VectorDataType#BINARY}
     */
    private static void requireBinaryType(String spaceName, VectorDataType vectorDataType) {
        if (VectorDataType.BINARY != vectorDataType) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "Incompatible field_type for %s space. The data type should be binary but got %s",
                    spaceName,
                    vectorDataType.getValue()
                )
            );
        }
    }

    /**
     * Tells whether query values for the given data type must pass
     * {@link KNNValidationUtil#validateByteVectorValue} before use.
     */
    private static boolean requiresByteValidation(VectorDataType vectorDataType) {
        return VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType;
    }

    /**
     * Computes the squared L2 distance between two float vectors via
     * {@link VectorUtil#squareDistance(float[], float[])}.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return squared L2 distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float l2Squared(float[] queryVector, float[] inputVector) {
        // Checked here (as for every other metric) so a mismatch produces the same message everywhere.
        requireEqualDimension(queryVector, inputVector);
        return VectorUtil.squareDistance(queryVector, inputVector);
    }

    /**
     * Computes the squared L2 distance between two byte vectors via
     * {@link VectorUtil#squareDistance(byte[], byte[])}.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return squared L2 distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float l2Squared(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        return VectorUtil.squareDistance(queryVector, inputVector);
    }

    /**
     * Copies a list of numbers into a {@code float[]}, validating each value with
     * {@link KNNValidationUtil#validateByteVectorValue} when the target field is byte or binary.
     *
     * @param inputVector    values to convert; must not be {@code null}
     * @param vectorDataType data type of the field the vector will be compared against
     * @return a new array holding {@code floatValue()} of every element, in order
     */
    private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
        Objects.requireNonNull(inputVector);
        // The data type does not change per element, so decide once up front.
        final boolean validate = requiresByteValidation(vectorDataType);
        float[] value = new float[inputVector.size()];
        int index = 0;
        for (Number val : inputVector) {
            float floatValue = val.floatValue();
            if (validate) {
                KNNValidationUtil.validateByteVectorValue(floatValue, vectorDataType);
            }
            value[index++] = floatValue;
        }
        return value;
    }

    /**
     * Copies a list of numbers into a {@code byte[]}, validating each value with
     * {@link KNNValidationUtil#validateByteVectorValue} when the target field is byte or binary.
     *
     * @param inputVector    values to convert; must not be {@code null}
     * @param vectorDataType data type of the field the vector will be compared against
     * @return a new array holding {@code byteValue()} of every element, in order
     */
    private static byte[] toByte(List<Number> inputVector, VectorDataType vectorDataType) {
        Objects.requireNonNull(inputVector);
        final boolean validate = requiresByteValidation(vectorDataType);
        byte[] value = new byte[inputVector.size()];
        int index = 0;
        for (Number val : inputVector) {
            if (validate) {
                // Validation looks at the float view so fractional or out-of-range input is seen
                // before byteValue() narrows it.
                KNNValidationUtil.validateByteVectorValue(val.floatValue(), vectorDataType);
            }
            value[index++] = val.byteValue();
        }
        return value;
    }

    /**
     * Computes the cosine similarity of two float vectors via {@link VectorUtil#cosine(float[], float[])}.
     *
     * <p>If Lucene rejects the vectors, or the result is not a finite number (for example the
     * 0/0 produced by a zero-magnitude vector), {@code 0.0f} is returned so the document sorts last.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return cosine similarity, or {@code 0.0f} when it cannot be computed
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float cosinesimil(float[] queryVector, float[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        try {
            float similarity = VectorUtil.cosine(queryVector, inputVector);
            // Lucene may only guard the result with an assert; with assertions disabled a
            // non-finite value would otherwise escape and be used as a score.
            if (Float.isFinite(similarity)) {
                return similarity;
            }
        } catch (AssertionError | IllegalArgumentException e) {
            // Deliberately handled below: an uncomputable cosine maps to the minimum score.
        }
        logger.debug(INVALID_COSINE_MESSAGE);
        return 0.0f;
    }

    /**
     * Computes the cosine similarity of two byte vectors via {@link VectorUtil#cosine(byte[], byte[])}.
     *
     * <p>If Lucene rejects the vectors, or the result is not a finite number (for example the
     * 0/0 produced by a zero-magnitude vector), {@code 0.0f} is returned so the document sorts last.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return cosine similarity, or {@code 0.0f} when it cannot be computed
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float cosinesimil(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        try {
            float similarity = VectorUtil.cosine(queryVector, inputVector);
            if (Float.isFinite(similarity)) {
                return similarity;
            }
        } catch (AssertionError | IllegalArgumentException e) {
            // Deliberately handled below: an uncomputable cosine maps to the minimum score.
        }
        logger.debug(INVALID_COSINE_MESSAGE);
        return 0.0f;
    }

    /**
     * Computes cosine similarity using a caller-supplied normalisation term for the query vector,
     * so only the stored vector has to be scanned for its norm.
     *
     * <p>The result is {@code dot(query, input) / sqrt(normQueryVector * sum(input[i]^2))}. Because
     * {@code normQueryVector} is multiplied by the input's sum of squares before the square root is
     * taken, it has to be the query's sum of squared components for the result to be a true cosine.
     *
     * @param queryVector     query vector
     * @param inputVector     stored vector
     * @param normQueryVector precomputed normalisation term for the query vector (see above)
     * @return the similarity, or {@code 0.0f} when the normalisation product is zero, negative or NaN
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
        requireEqualDimension(queryVector, inputVector);
        float dotProduct = 0.0f;
        float normInputVector = 0.0f;
        for (int i = 0; i < queryVector.length; ++i) {
            dotProduct += queryVector[i] * inputVector[i];
            normInputVector += inputVector[i] * inputVector[i];
        }
        float normalizedProduct = normQueryVector * normInputVector;
        // Written as a negated ">" so that NaN and negative products (which would make
        // Math.sqrt return NaN) take the same fallback as an exact zero.
        if (!(normalizedProduct > 0.0f)) {
            logger.debug(INVALID_COSINE_MESSAGE);
            return 0.0f;
        }
        return (float) ((double) dotProduct / Math.sqrt(normalizedProduct));
    }

    /**
     * Counts the bits that differ between two {@link BigInteger} values.
     *
     * @param queryBigInteger query value
     * @param inputBigInteger stored value
     * @return {@code bitCount()} of the XOR of the two values
     */
    public static float calculateHammingBit(BigInteger queryBigInteger, BigInteger inputBigInteger) {
        return inputBigInteger.xor(queryBigInteger).bitCount();
    }

    /**
     * Counts the bits that differ between two long values.
     *
     * @param queryLong query value; unboxed, so {@code null} raises {@link NullPointerException}
     * @param inputLong stored value; unboxed, so {@code null} raises {@link NullPointerException}
     * @return number of set bits in {@code queryLong ^ inputLong}
     */
    public static float calculateHammingBit(Long queryLong, Long inputLong) {
        return Long.bitCount(queryLong ^ inputLong);
    }

    /**
     * Counts the bits that differ between two equally sized byte arrays via
     * {@link VectorUtil#xorBitCount(byte[], byte[])}.
     *
     * @param queryVector query bytes
     * @param inputVector stored bytes
     * @return number of differing bits
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static float calculateHammingBit(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        return VectorUtil.xorBitCount(queryVector, inputVector);
    }

    /**
     * Computes the L1 (sum of absolute differences) distance between two float vectors.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return L1 distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float l1Norm(float[] queryVector, float[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            distance += Math.abs(queryVector[i] - inputVector[i]);
        }
        return distance;
    }

    /**
     * Computes the L1 (sum of absolute differences) distance between two byte vectors.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return L1 distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float l1Norm(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            // The subtraction happens in int, so it cannot overflow the byte range.
            float diff = queryVector[i] - inputVector[i];
            distance += Math.abs(diff);
        }
        return distance;
    }

    /**
     * Computes the L-infinity (largest absolute difference) distance between two float vectors.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return L-infinity distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float lInfNorm(float[] queryVector, float[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            float diff = queryVector[i] - inputVector[i];
            distance = Math.max(Math.abs(diff), distance);
        }
        return distance;
    }

    /**
     * Computes the L-infinity (largest absolute difference) distance between two byte vectors.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return L-infinity distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float lInfNorm(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            // The subtraction happens in int, so it cannot overflow the byte range.
            float diff = queryVector[i] - inputVector[i];
            distance = Math.max(Math.abs(diff), distance);
        }
        return distance;
    }

    /**
     * Computes the inner (dot) product of two float vectors via
     * {@link VectorUtil#dotProduct(float[], float[])}.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return dot product
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float innerProduct(float[] queryVector, float[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        return VectorUtil.dotProduct(queryVector, inputVector);
    }

    /**
     * Computes the inner (dot) product of two byte vectors via
     * {@link VectorUtil#dotProduct(byte[], byte[])}.
     *
     * @param queryVector query vector
     * @param inputVector stored vector
     * @return dot product
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static float innerProduct(byte[] queryVector, byte[] inputVector) {
        requireEqualDimension(queryVector, inputVector);
        return VectorUtil.dotProduct(queryVector, inputVector);
    }

    /**
     * Computes the squared L2 distance between a query and the current document's vector.
     *
     * @param queryVector query values
     * @param docValues   doc values exposing the stored vector and its data type
     * @return squared L2 distance
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("l2Squared", vectorDataType);
        if (VectorDataType.FLOAT == vectorDataType) {
            return l2Squared(toFloat(queryVector, vectorDataType), (float[]) docValues.getValue());
        }
        return l2Squared(toByte(queryVector, vectorDataType), (byte[]) docValues.getValue());
    }

    /**
     * Computes the L-infinity distance between a query and the current document's vector.
     *
     * @param queryVector query values
     * @param docValues   doc values exposing the stored vector and its data type
     * @return L-infinity distance
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float lInfNorm(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("lInfNorm", vectorDataType);
        if (VectorDataType.FLOAT == vectorDataType) {
            return lInfNorm(toFloat(queryVector, vectorDataType), (float[]) docValues.getValue());
        }
        return lInfNorm(toByte(queryVector, vectorDataType), (byte[]) docValues.getValue());
    }

    /**
     * Computes the L1 distance between a query and the current document's vector.
     *
     * @param queryVector query values
     * @param docValues   doc values exposing the stored vector and its data type
     * @return L1 distance
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("l1Norm", vectorDataType);
        if (VectorDataType.FLOAT == vectorDataType) {
            return l1Norm(toFloat(queryVector, vectorDataType), (float[]) docValues.getValue());
        }
        return l1Norm(toByte(queryVector, vectorDataType), (byte[]) docValues.getValue());
    }

    /**
     * Computes the inner product of a query and the current document's vector.
     *
     * @param queryVector query values
     * @param docValues   doc values exposing the stored vector and its data type
     * @return dot product
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float innerProduct(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("innerProduct", vectorDataType);
        if (VectorDataType.FLOAT == vectorDataType) {
            return innerProduct(toFloat(queryVector, vectorDataType), (float[]) docValues.getValue());
        }
        return innerProduct(toByte(queryVector, vectorDataType), (byte[]) docValues.getValue());
    }

    /**
     * Computes the cosine similarity of a query and the current document's vector.
     *
     * <p>The converted query is passed through {@code SpaceType.COSINESIMIL.validateVector}
     * before scoring.
     *
     * @param queryVector query values
     * @param docValues   doc values exposing the stored vector and its data type
     * @return cosine similarity, or {@code 0.0f} when it cannot be computed
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("cosineSimilarity", vectorDataType);
        if (VectorDataType.FLOAT == vectorDataType) {
            float[] floatQuery = toFloat(queryVector, vectorDataType);
            SpaceType.COSINESIMIL.validateVector(floatQuery);
            return cosinesimil(floatQuery, (float[]) docValues.getValue());
        }
        byte[] byteQuery = toByte(queryVector, vectorDataType);
        SpaceType.COSINESIMIL.validateVector(byteQuery);
        return cosinesimil(byteQuery, (byte[]) docValues.getValue());
    }

    /**
     * Computes the cosine similarity of a query and the current document's vector, reusing a
     * precomputed normalisation term for the query (see
     * {@link #cosinesimilOptimized(float[], float[], float)} for how that term is applied).
     *
     * <p>Byte document vectors are widened to {@code float} so both data types share the same
     * scoring routine.
     *
     * @param queryVector          query values
     * @param docValues            doc values exposing the stored vector and its data type
     * @param queryVectorMagnitude precomputed normalisation term for the query vector
     * @return the similarity, or {@code 0.0f} when it cannot be computed
     * @throws IllegalArgumentException if the field is binary or the dimensions differ
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues, Number queryVectorMagnitude) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireNonBinaryType("cosineSimilarity", vectorDataType);
        float[] inputVector = toFloat(queryVector, vectorDataType);
        SpaceType.COSINESIMIL.validateVector(inputVector);
        if (VectorDataType.FLOAT == vectorDataType) {
            return cosinesimilOptimized(inputVector, (float[]) docValues.getValue(), queryVectorMagnitude.floatValue());
        }
        byte[] docVectorInByte = (byte[]) docValues.getValue();
        float[] docVectorInFloat = new float[docVectorInByte.length];
        for (int i = 0; i < docVectorInByte.length; ++i) {
            docVectorInFloat[i] = docVectorInByte[i];
        }
        return cosinesimilOptimized(inputVector, docVectorInFloat, queryVectorMagnitude.floatValue());
    }

    /**
     * Computes the Hamming distance (number of differing bits) between a query and the current
     * document's binary vector.
     *
     * @param queryVector query values, converted to bytes
     * @param docValues   doc values exposing the stored vector and its data type
     * @return number of differing bits
     * @throws IllegalArgumentException if the field is not binary or the dimensions differ
     */
    public static float hamming(List<Number> queryVector, KNNVectorScriptDocValues<?> docValues) {
        final VectorDataType vectorDataType = docValues.getVectorDataType();
        requireBinaryType("hamming", vectorDataType);
        byte[] queryVectorInByte = toByte(queryVector, vectorDataType);
        return calculateHammingBit(queryVectorInByte, (byte[]) docValues.getValue());
    }
}

