/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.vector.internal;

import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
import org.hibernate.vector.internal.VectorArgumentTypeResolver;
import org.hibernate.vector.internal.VectorArgumentValidator;

/**
 * Registers the SQL functions of the vector module with the
 * {@link SqmFunctionRegistry} obtained from a {@link FunctionContributions}.
 * <p>
 * Distance functions are registered from a caller-supplied SQL pattern. They take exactly two
 * arguments and always return {@code DOUBLE}. Some of them also get an alternate key (a second
 * name) in the registry. Utility functions such as {@code vector_dims} and {@code vector_norm} are
 * registered as named functions with a fixed argument count and return type.
 */
public class VectorFunctionFactory {
    /** Registry key shared by the Euclidean distance function and its {@code l2_distance} alias. */
    private static final String EUCLIDEAN_DISTANCE = "euclidean_distance";
    /** Registry key shared by the squared Euclidean distance function and its alias. */
    private static final String EUCLIDEAN_SQUARED_DISTANCE = "euclidean_squared_distance";
    /** Registry key shared by the L1 distance function and its {@code taxicab_distance} alias. */
    private static final String L1_DISTANCE = "l1_distance";

    private final SqmFunctionRegistry functionRegistry;
    private final TypeConfiguration typeConfiguration;
    private final BasicType<Double> doubleType;
    private final BasicType<Integer> integerType;

    /**
     * Binds this factory to the function registry and type configuration of the given
     * contributions.
     * <p>
     * The {@code DOUBLE} and {@code INTEGER} basic types are resolved up front so they can be used
     * as function return types.
     *
     * @param functionContributions provides the target function registry and type configuration
     */
    public VectorFunctionFactory(FunctionContributions functionContributions) {
        functionRegistry = functionContributions.getFunctionRegistry();
        typeConfiguration = functionContributions.getTypeConfiguration();
        final BasicTypeRegistry registry = typeConfiguration.getBasicTypeRegistry();
        doubleType = registry.resolve(StandardBasicTypes.DOUBLE);
        integerType = registry.resolve(StandardBasicTypes.INTEGER);
    }

    /**
     * Registers {@code cosine_distance}, rendered with the given SQL pattern.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void cosineDistance(String pattern) {
        registerVectorDistanceFunction("cosine_distance", pattern);
    }

    /**
     * Registers {@code euclidean_distance} with the given SQL pattern.
     * <p>
     * Also adds {@code l2_distance} as an alternate key for it.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void euclideanDistance(String pattern) {
        registerVectorDistanceFunction(EUCLIDEAN_DISTANCE, pattern);
        functionRegistry.registerAlternateKey("l2_distance", EUCLIDEAN_DISTANCE);
    }

    /**
     * Registers {@code euclidean_squared_distance} with the given SQL pattern.
     * <p>
     * Also adds {@code l2_squared_distance} as an alternate key for it.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void euclideanSquaredDistance(String pattern) {
        registerVectorDistanceFunction(EUCLIDEAN_SQUARED_DISTANCE, pattern);
        functionRegistry.registerAlternateKey("l2_squared_distance", EUCLIDEAN_SQUARED_DISTANCE);
    }

    /**
     * Registers {@code l1_distance} with the given SQL pattern.
     * <p>
     * Also adds {@code taxicab_distance} as an alternate key for it.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void l1Distance(String pattern) {
        registerVectorDistanceFunction(L1_DISTANCE, pattern);
        functionRegistry.registerAlternateKey("taxicab_distance", L1_DISTANCE);
    }

    /**
     * Registers {@code inner_product}, rendered with the given SQL pattern.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void innerProduct(String pattern) {
        registerVectorDistanceFunction("inner_product", pattern);
    }

    /**
     * Registers {@code negative_inner_product}, rendered with the given SQL pattern.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void negativeInnerProduct(String pattern) {
        registerVectorDistanceFunction("negative_inner_product", pattern);
    }

    /**
     * Registers {@code hamming_distance}, rendered with the given SQL pattern.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void hammingDistance(String pattern) {
        registerVectorDistanceFunction("hamming_distance", pattern);
    }

    /**
     * Registers {@code jaccard_distance}, rendered with the given SQL pattern.
     *
     * @param pattern SQL rendering pattern for the function
     */
    public void jaccardDistance(String pattern) {
        registerVectorDistanceFunction("jaccard_distance", pattern);
    }

    /** Registers the single-argument named function {@code vector_dims}, returning {@code INTEGER}. */
    public void vectorDimensions() {
        registerNamedVectorFunction("vector_dims", integerType, 1);
    }

    /** Registers the single-argument named function {@code vector_norm}, returning {@code DOUBLE}. */
    public void vectorNorm() {
        registerNamedVectorFunction("vector_norm", doubleType, 1);
    }

    /**
     * Registers a pattern-based distance function.
     * <p>
     * The function accepts exactly two arguments, which are checked by
     * {@code VectorArgumentValidator.DISTANCE_INSTANCE}. Argument types are inferred by
     * {@code VectorArgumentTypeResolver.DISTANCE_INSTANCE}, and the result type is always
     * {@code DOUBLE}.
     *
     * @param functionName registry key of the function
     * @param pattern SQL rendering pattern for the function
     */
    public void registerVectorDistanceFunction(String functionName, String pattern) {
        functionRegistry.patternDescriptorBuilder(functionName, pattern)
                .setArgumentsValidator(vectorArguments(2, VectorArgumentValidator.DISTANCE_INSTANCE))
                .setArgumentTypeResolver(VectorArgumentTypeResolver.DISTANCE_INSTANCE)
                .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(doubleType))
                .register();
    }

    /**
     * Registers a named (non-pattern) vector function.
     * <p>
     * The function accepts exactly {@code argumentCount} arguments, which are checked by
     * {@code VectorArgumentValidator.INSTANCE}, and always returns {@code returnType}.
     *
     * @param functionName registry key of the function
     * @param returnType fixed result type of the function
     * @param argumentCount exact number of arguments the function accepts
     */
    public void registerNamedVectorFunction(String functionName, BasicType<?> returnType, int argumentCount) {
        functionRegistry.namedDescriptorBuilder(functionName)
                .setArgumentsValidator(vectorArguments(argumentCount, VectorArgumentValidator.INSTANCE))
                .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
                .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
                .register();
    }

    /**
     * Registers a pattern-based vector function.
     * <p>
     * The function accepts exactly {@code argumentCount} arguments, which are checked by
     * {@code VectorArgumentValidator.INSTANCE}, and always returns {@code returnType}.
     *
     * @param functionName registry key of the function
     * @param pattern SQL rendering pattern for the function
     * @param returnType fixed result type of the function
     * @param argumentCount exact number of arguments the function accepts
     */
    public void registerPatternVectorFunction(String functionName, String pattern, BasicType<?> returnType, int argumentCount) {
        functionRegistry.patternDescriptorBuilder(functionName, pattern)
                .setArgumentsValidator(vectorArguments(argumentCount, VectorArgumentValidator.INSTANCE))
                .setArgumentTypeResolver(VectorArgumentTypeResolver.INSTANCE)
                .setReturnTypeResolver(StandardFunctionReturnTypeResolvers.invariant(returnType))
                .register();
    }

    /**
     * Combines an exact-arity check with a vector-specific argument validator.
     * <p>
     * The arity check comes first in the composite, the vector validator second.
     *
     * @param count exact number of arguments required
     * @param vectorValidator additional validator applied to the arguments
     * @return the composite validator
     */
    private static ArgumentsValidator vectorArguments(int count, ArgumentsValidator vectorValidator) {
        return StandardArgumentsValidators.composite(
                StandardArgumentsValidators.exactly(count),
                vectorValidator
        );
    }
}

