/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.tool.augment;

import com.fasterxml.jackson.core.type.TypeReference;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.augment.AugmentedArgumentEvent;
import org.springframework.ai.tool.augment.ToolInputSchemaAugmenter;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.util.Assert;

/**
 * A {@link ToolCallback} decorator that extends the delegate's input schema with extra
 * ("augmented") arguments derived from the components of a {@link Record} type.
 * <p>
 * On every invocation the raw tool input is optionally deserialized into the augmented
 * record type and published to a consumer as an {@link AugmentedArgumentEvent}. The
 * augmented arguments are then optionally stripped from the input before it is
 * forwarded to the delegate.
 *
 * @param <T> the record type whose components define the augmented arguments
 */
public class AugmentedToolCallback<T extends Record> implements ToolCallback {

    /** Type token for parsing the raw tool input as a generic JSON object; stateless, so shared. */
    private static final TypeReference<Map<String, Object>> JSON_OBJECT_TYPE =
            new TypeReference<Map<String, Object>>() {
            };

    private final ToolCallback delegate;

    private final ToolDefinition augmentedToolDefinition;

    private final Class<T> augmentedArgumentsClass;

    private final @Nullable Consumer<AugmentedArgumentEvent<T>> augmentedArgumentsConsumer;

    private final List<ToolInputSchemaAugmenter.AugmentedArgumentType> augmentedArgumentTypes;

    private final boolean removeAugmentedArgumentsAfterProcessing;

    /**
     * Creates a callback that exposes the delegate's tool definition with an input schema
     * augmented by the components of {@code augmentedArgumentsClass}.
     *
     * @param delegate the callback that performs the actual tool invocation; must not be null
     * @param augmentedArgumentsClass record type describing the augmented arguments; must be
     * a record with at least one component
     * @param augmentedArgumentsConsumer optional consumer notified with the parsed augmented
     * arguments on every call; may be {@code null}
     * @param removeAugmentedArgumentsAfterProcessing whether the augmented arguments are
     * removed from the tool input before it is passed to the delegate
     * @throws IllegalArgumentException if any of the preconditions above is violated
     */
    public AugmentedToolCallback(ToolCallback delegate, Class<T> augmentedArgumentsClass,
            @Nullable Consumer<AugmentedArgumentEvent<T>> augmentedArgumentsConsumer,
            boolean removeAugmentedArgumentsAfterProcessing) {
        Assert.notNull(delegate, "Delegate ToolCallback must not be null");
        Assert.notNull(augmentedArgumentsClass, "Argument types must not be null");
        Assert.isTrue(augmentedArgumentsClass.isRecord(), "Argument types must be a Record type");
        Assert.isTrue(augmentedArgumentsClass.getRecordComponents().length > 0,
                "Argument types must have at least one field");

        this.delegate = delegate;
        this.augmentedArgumentsClass = augmentedArgumentsClass;
        this.augmentedArgumentsConsumer = augmentedArgumentsConsumer;
        this.removeAugmentedArgumentsAfterProcessing = removeAugmentedArgumentsAfterProcessing;
        this.augmentedArgumentTypes = ToolInputSchemaAugmenter.toAugmentedArgumentTypes(augmentedArgumentsClass);

        // Fetch the delegate's definition once and derive the augmented one from it.
        ToolDefinition originalDefinition = delegate.getToolDefinition();
        String augmentedSchema = ToolInputSchemaAugmenter
            .augmentToolInputSchema(originalDefinition.inputSchema(), this.augmentedArgumentTypes);
        this.augmentedToolDefinition = ToolDefinition.builder()
            .name(originalDefinition.name())
            .description(originalDefinition.description())
            .inputSchema(augmentedSchema)
            .build();
    }

    /**
     * Returns the delegate's tool definition with the augmented input schema.
     *
     * @return the augmented tool definition built at construction time
     */
    @Override
    public ToolDefinition getToolDefinition() {
        return this.augmentedToolDefinition;
    }

    /**
     * Processes the augmented arguments and forwards the (possibly stripped) input to the
     * delegate.
     *
     * @param toolInput the raw JSON tool input
     * @return the delegate's result
     */
    @Override
    public String call(String toolInput) {
        return this.delegate.call(handleAugmentedArguments(toolInput));
    }

    /**
     * Processes the augmented arguments and forwards the (possibly stripped) input and the
     * tool context to the delegate.
     *
     * @param toolInput the raw JSON tool input
     * @param toolContext optional tool context passed through unchanged to the delegate
     * @return the delegate's result
     */
    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        return this.delegate.call(handleAugmentedArguments(toolInput), toolContext);
    }

    /**
     * Publishes the augmented arguments to the consumer (if any) and, when configured,
     * strips them from the tool input.
     *
     * @param toolInput the raw JSON tool input
     * @return the input to forward to the delegate
     */
    private String handleAugmentedArguments(String toolInput) {
        if (this.augmentedArgumentsConsumer != null) {
            // Keep the parsed value typed as T so the event matches the consumer's generic type.
            T augmentedArguments = JsonParser.fromJson(toolInput, this.augmentedArgumentsClass);
            this.augmentedArgumentsConsumer
                .accept(new AugmentedArgumentEvent<T>(this.augmentedToolDefinition, toolInput, augmentedArguments));
        }
        if (!this.removeAugmentedArgumentsAfterProcessing) {
            return toolInput;
        }
        Map<String, Object> args = JsonParser.fromJson(toolInput, JSON_OBJECT_TYPE);
        boolean removedAny = false;
        for (ToolInputSchemaAugmenter.AugmentedArgumentType argumentType : this.augmentedArgumentTypes) {
            // keySet().remove reports key presence even when the mapped value is JSON null.
            removedAny |= args.keySet().remove(argumentType.name());
        }
        // Only re-serialize when something changed; the Map round-trip can alter formatting
        // and numeric precision of the original input.
        return removedAny ? JsonParser.toJson(args) : toolInput;
    }
}

