/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
 * the License. A copy of the License is located at
 * 
 * http://aws.amazon.com/apache2.0
 * 
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package software.amazon.awssdk.services.bedrock.model;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.Function;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.Mutable;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.LocationTrait;
import software.amazon.awssdk.utils.ToString;
import software.amazon.awssdk.utils.builder.CopyableBuilder;
import software.amazon.awssdk.utils.builder.ToCopyableBuilder;

/**
 * <p>
 * Hyperparameters for controlling the reinforcement fine-tuning (RFT) training process, including learning settings
 * and evaluation intervals.
 * </p>
 * <p>
 * Instances are immutable: every member is held in a {@code private final} field of an immutable type
 * ({@link Integer}, {@link Float} or {@link String}) and is copied once from a {@link Builder}. Use {@link #builder()}
 * to create a new instance, or {@link #toBuilder()} to derive a modified copy.
 * </p>
 */
@Generated("software.amazon.awssdk:codegen")
public final class RFTHyperParameters implements SdkPojo, Serializable,
        ToCopyableBuilder<RFTHyperParameters.Builder, RFTHyperParameters> {
    // Per-member marshalling metadata. Each SdkField binds a payload member name to this class's getter and to the
    // Builder's setter. This class is code-generated (see @Generated), so member/location names should be changed in
    // the service model rather than edited by hand.
    private static final SdkField<Integer> EPOCH_COUNT_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("epochCount").getter(getter(RFTHyperParameters::epochCount)).setter(setter(Builder::epochCount))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("epochCount").build()).build();

    private static final SdkField<Integer> BATCH_SIZE_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("batchSize").getter(getter(RFTHyperParameters::batchSize)).setter(setter(Builder::batchSize))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("batchSize").build()).build();

    private static final SdkField<Float> LEARNING_RATE_FIELD = SdkField.<Float> builder(MarshallingType.FLOAT)
            .memberName("learningRate").getter(getter(RFTHyperParameters::learningRate)).setter(setter(Builder::learningRate))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("learningRate").build()).build();

    private static final SdkField<Integer> MAX_PROMPT_LENGTH_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("maxPromptLength").getter(getter(RFTHyperParameters::maxPromptLength))
            .setter(setter(Builder::maxPromptLength))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("maxPromptLength").build()).build();

    private static final SdkField<Integer> TRAINING_SAMPLE_PER_PROMPT_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("trainingSamplePerPrompt").getter(getter(RFTHyperParameters::trainingSamplePerPrompt))
            .setter(setter(Builder::trainingSamplePerPrompt))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("trainingSamplePerPrompt").build())
            .build();

    private static final SdkField<Integer> INFERENCE_MAX_TOKENS_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("inferenceMaxTokens").getter(getter(RFTHyperParameters::inferenceMaxTokens))
            .setter(setter(Builder::inferenceMaxTokens))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("inferenceMaxTokens").build())
            .build();

    // Marshalled via the raw-string getter (not the enum) so that values unknown to this SDK version are preserved.
    private static final SdkField<String> REASONING_EFFORT_FIELD = SdkField.<String> builder(MarshallingType.STRING)
            .memberName("reasoningEffort").getter(getter(RFTHyperParameters::reasoningEffortAsString))
            .setter(setter(Builder::reasoningEffort))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("reasoningEffort").build()).build();

    private static final SdkField<Integer> EVAL_INTERVAL_FIELD = SdkField.<Integer> builder(MarshallingType.INTEGER)
            .memberName("evalInterval").getter(getter(RFTHyperParameters::evalInterval)).setter(setter(Builder::evalInterval))
            .traits(LocationTrait.builder().location(MarshallLocation.PAYLOAD).locationName("evalInterval").build()).build();

    // Every member's SdkField, in declaration order; returned by sdkFields(). Keep in sync with
    // memberNameToFieldInitializer().
    private static final List<SdkField<?>> SDK_FIELDS = Collections.unmodifiableList(Arrays.asList(EPOCH_COUNT_FIELD,
            BATCH_SIZE_FIELD, LEARNING_RATE_FIELD, MAX_PROMPT_LENGTH_FIELD, TRAINING_SAMPLE_PER_PROMPT_FIELD,
            INFERENCE_MAX_TOKENS_FIELD, REASONING_EFFORT_FIELD, EVAL_INTERVAL_FIELD));

    // Must stay textually AFTER the *_FIELD constants: static initializers run in source order, so moving this line
    // earlier would make memberNameToFieldInitializer() read (and store) null fields.
    private static final Map<String, SdkField<?>> SDK_NAME_TO_FIELD = memberNameToFieldInitializer();

    private static final long serialVersionUID = 1L;

    // All members are boxed/nullable: a member that was never set on the builder is null here.
    private final Integer epochCount;

    private final Integer batchSize;

    private final Float learningRate;

    private final Integer maxPromptLength;

    private final Integer trainingSamplePerPrompt;

    private final Integer inferenceMaxTokens;

    // Stored as the raw service string (see reasoningEffortAsString()); reasoningEffort() converts it to the enum.
    private final String reasoningEffort;

    private final Integer evalInterval;

    /**
     * Copies every member value out of the given builder. Only reachable through {@link BuilderImpl#build()}.
     *
     * @param builder
     *        the builder whose current values become this instance's (immutable) values
     */
    private RFTHyperParameters(BuilderImpl builder) {
        this.epochCount = builder.epochCount;
        this.batchSize = builder.batchSize;
        this.learningRate = builder.learningRate;
        this.maxPromptLength = builder.maxPromptLength;
        this.trainingSamplePerPrompt = builder.trainingSamplePerPrompt;
        this.inferenceMaxTokens = builder.inferenceMaxTokens;
        this.reasoningEffort = builder.reasoningEffort;
        this.evalInterval = builder.evalInterval;
    }

    /**
     * <p>
     * Number of training epochs to run during reinforcement fine-tuning. Higher values may improve performance but
     * increase training time.
     * </p>
     * 
     * @return Number of training epochs to run during reinforcement fine-tuning. Higher values may improve performance
     *         but increase training time. {@code null} if not set.
     */
    public final Integer epochCount() {
        return epochCount;
    }

    /**
     * <p>
     * Number of training samples processed in each batch during reinforcement fine-tuning (RFT) training. Larger
     * batches may improve training stability.
     * </p>
     * 
     * @return Number of training samples processed in each batch during reinforcement fine-tuning (RFT) training.
     *         Larger batches may improve training stability. {@code null} if not set.
     */
    public final Integer batchSize() {
        return batchSize;
    }

    /**
     * <p>
     * Learning rate for the reinforcement fine-tuning. Controls how quickly the model adapts to reward signals.
     * </p>
     * 
     * @return Learning rate for the reinforcement fine-tuning. Controls how quickly the model adapts to reward signals.
     *         {@code null} if not set.
     */
    public final Float learningRate() {
        return learningRate;
    }

    /**
     * <p>
     * Maximum length of input prompts during RFT training, measured in tokens. Longer prompts allow more context but
     * increase memory usage and training time.
     * </p>
     * 
     * @return Maximum length of input prompts during RFT training, measured in tokens. Longer prompts allow more
     *         context but increase memory usage and training time. {@code null} if not set.
     */
    public final Integer maxPromptLength() {
        return maxPromptLength;
    }

    /**
     * <p>
     * Number of response samples generated per prompt during RFT training. More samples provide better reward signal
     * estimation.
     * </p>
     * 
     * @return Number of response samples generated per prompt during RFT training. More samples provide better reward
     *         signal estimation. {@code null} if not set.
     */
    public final Integer trainingSamplePerPrompt() {
        return trainingSamplePerPrompt;
    }

    /**
     * <p>
     * Maximum number of tokens the model can generate in response to each prompt during RFT training.
     * </p>
     * 
     * @return Maximum number of tokens the model can generate in response to each prompt during RFT training.
     *         {@code null} if not set.
     */
    public final Integer inferenceMaxTokens() {
        return inferenceMaxTokens;
    }

    /**
     * <p>
     * Level of reasoning effort applied during RFT training. Higher values may improve response quality but increase
     * training time.
     * </p>
     * <p>
     * If the service returns an enum value that is not available in the current SDK version, {@link #reasoningEffort}
     * will return {@link ReasoningEffort#UNKNOWN_TO_SDK_VERSION}. The raw value returned by the service is available
     * from {@link #reasoningEffortAsString}.
     * </p>
     * 
     * @return Level of reasoning effort applied during RFT training. Higher values may improve response quality but
     *         increase training time.
     * @see ReasoningEffort
     */
    public final ReasoningEffort reasoningEffort() {
        // Converted on every call from the stored raw string; the enum itself is never stored.
        return ReasoningEffort.fromValue(reasoningEffort);
    }

    /**
     * <p>
     * Level of reasoning effort applied during RFT training. Higher values may improve response quality but increase
     * training time.
     * </p>
     * <p>
     * If the service returns an enum value that is not available in the current SDK version, {@link #reasoningEffort}
     * will return {@link ReasoningEffort#UNKNOWN_TO_SDK_VERSION}. The raw value returned by the service is available
     * from {@link #reasoningEffortAsString}.
     * </p>
     * 
     * @return Level of reasoning effort applied during RFT training, as the raw string exactly as stored; {@code null}
     *         if not set.
     * @see ReasoningEffort
     */
    public final String reasoningEffortAsString() {
        return reasoningEffort;
    }

    /**
     * <p>
     * Interval between evaluation runs during RFT training, measured in training steps. More frequent evaluation
     * provides better monitoring.
     * </p>
     * 
     * @return Interval between evaluation runs during RFT training, measured in training steps. More frequent
     *         evaluation provides better monitoring. {@code null} if not set.
     */
    public final Integer evalInterval() {
        return evalInterval;
    }

    /**
     * Returns a new mutable builder pre-populated with this instance's values. Changes made to the returned builder
     * do not affect this (immutable) instance.
     *
     * @return a new builder initialised from this instance
     */
    @Override
    public Builder toBuilder() {
        return new BuilderImpl(this);
    }

    /**
     * Returns a new, empty builder; every member starts out {@code null}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new BuilderImpl();
    }

    /**
     * Returns the concrete builder class. NOTE(review): presumably used by frameworks that (de)serialize builders
     * reflectively through the bean-style get/set accessors on {@code BuilderImpl} - confirm with SDK callers.
     *
     * @return {@code BuilderImpl.class}
     */
    public static Class<? extends Builder> serializableBuilderClass() {
        return BuilderImpl.class;
    }

    /**
     * Hashes every member. Uses the raw reasoningEffort string (not the enum) so that the result stays consistent
     * with {@link #equalsBySdkFields(Object)}, which compares the same raw strings.
     */
    @Override
    public final int hashCode() {
        int hashCode = 1;
        hashCode = 31 * hashCode + Objects.hashCode(epochCount());
        hashCode = 31 * hashCode + Objects.hashCode(batchSize());
        hashCode = 31 * hashCode + Objects.hashCode(learningRate());
        hashCode = 31 * hashCode + Objects.hashCode(maxPromptLength());
        hashCode = 31 * hashCode + Objects.hashCode(trainingSamplePerPrompt());
        hashCode = 31 * hashCode + Objects.hashCode(inferenceMaxTokens());
        hashCode = 31 * hashCode + Objects.hashCode(reasoningEffortAsString());
        hashCode = 31 * hashCode + Objects.hashCode(evalInterval());
        return hashCode;
    }

    /**
     * Value equality over all members; delegates to {@link #equalsBySdkFields(Object)}.
     */
    @Override
    public final boolean equals(Object obj) {
        return equalsBySdkFields(obj);
    }

    /**
     * Returns {@code true} if {@code obj} is an {@code RFTHyperParameters} whose members all equal this instance's
     * (null-safe). reasoningEffort is compared as the raw string, so two different values unknown to this SDK
     * version are not considered equal.
     */
    @Override
    public final boolean equalsBySdkFields(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (!(obj instanceof RFTHyperParameters)) {
            return false;
        }
        RFTHyperParameters other = (RFTHyperParameters) obj;
        return Objects.equals(epochCount(), other.epochCount()) && Objects.equals(batchSize(), other.batchSize())
                && Objects.equals(learningRate(), other.learningRate())
                && Objects.equals(maxPromptLength(), other.maxPromptLength())
                && Objects.equals(trainingSamplePerPrompt(), other.trainingSamplePerPrompt())
                && Objects.equals(inferenceMaxTokens(), other.inferenceMaxTokens())
                && Objects.equals(reasoningEffortAsString(), other.reasoningEffortAsString())
                && Objects.equals(evalInterval(), other.evalInterval());
    }

    /**
     * Returns a string representation of this object. This is useful for testing and debugging. Sensitive data will be
     * redacted from this string using a placeholder value.
     */
    @Override
    public final String toString() {
        return ToString.builder("RFTHyperParameters").add("EpochCount", epochCount()).add("BatchSize", batchSize())
                .add("LearningRate", learningRate()).add("MaxPromptLength", maxPromptLength())
                .add("TrainingSamplePerPrompt", trainingSamplePerPrompt()).add("InferenceMaxTokens", inferenceMaxTokens())
                .add("ReasoningEffort", reasoningEffortAsString()).add("EvalInterval", evalInterval()).build();
    }

    /**
     * Looks up a member value by its service member name (for example {@code "epochCount"}).
     *
     * @param fieldName
     *        the member name, as used in {@link #sdkFieldNameToField()}
     * @param clazz
     *        the expected value type; {@link Class#cast(Object)} throws {@link ClassCastException} if the stored value
     *        is not of this type
     * @return the member value, or {@link Optional#empty()} if the member is unset or {@code fieldName} is unknown.
     *         reasoningEffort is returned as its raw {@link String}.
     */
    public final <T> Optional<T> getValueForField(String fieldName, Class<T> clazz) {
        switch (fieldName) {
        case "epochCount":
            return Optional.ofNullable(clazz.cast(epochCount()));
        case "batchSize":
            return Optional.ofNullable(clazz.cast(batchSize()));
        case "learningRate":
            return Optional.ofNullable(clazz.cast(learningRate()));
        case "maxPromptLength":
            return Optional.ofNullable(clazz.cast(maxPromptLength()));
        case "trainingSamplePerPrompt":
            return Optional.ofNullable(clazz.cast(trainingSamplePerPrompt()));
        case "inferenceMaxTokens":
            return Optional.ofNullable(clazz.cast(inferenceMaxTokens()));
        case "reasoningEffort":
            return Optional.ofNullable(clazz.cast(reasoningEffortAsString()));
        case "evalInterval":
            return Optional.ofNullable(clazz.cast(evalInterval()));
        default:
            return Optional.empty();
        }
    }

    /** Returns the shared, unmodifiable list of every member's {@link SdkField}. */
    @Override
    public final List<SdkField<?>> sdkFields() {
        return SDK_FIELDS;
    }

    /** Returns the shared, unmodifiable member-name to {@link SdkField} lookup. */
    @Override
    public final Map<String, SdkField<?>> sdkFieldNameToField() {
        return SDK_NAME_TO_FIELD;
    }

    /**
     * Builds the unmodifiable member-name to {@link SdkField} lookup. Called from a static initializer, so it relies
     * on every *_FIELD constant already being initialized. Keep in sync with SDK_FIELDS.
     */
    private static Map<String, SdkField<?>> memberNameToFieldInitializer() {
        Map<String, SdkField<?>> map = new HashMap<>();
        map.put("epochCount", EPOCH_COUNT_FIELD);
        map.put("batchSize", BATCH_SIZE_FIELD);
        map.put("learningRate", LEARNING_RATE_FIELD);
        map.put("maxPromptLength", MAX_PROMPT_LENGTH_FIELD);
        map.put("trainingSamplePerPrompt", TRAINING_SAMPLE_PER_PROMPT_FIELD);
        map.put("inferenceMaxTokens", INFERENCE_MAX_TOKENS_FIELD);
        map.put("reasoningEffort", REASONING_EFFORT_FIELD);
        map.put("evalInterval", EVAL_INTERVAL_FIELD);
        return Collections.unmodifiableMap(map);
    }

    /** Adapts a typed getter into the {@code Object}-typed function passed to {@code SdkField.getter}. */
    private static <T> Function<Object, T> getter(Function<RFTHyperParameters, T> g) {
        return obj -> g.apply((RFTHyperParameters) obj);
    }

    /** Adapts a typed Builder setter into the {@code Object}-typed consumer passed to {@code SdkField.setter}. */
    private static <T> BiConsumer<Object, T> setter(BiConsumer<Builder, T> s) {
        return (obj, val) -> s.accept((Builder) obj, val);
    }

    /**
     * Mutable builder for {@link RFTHyperParameters}. Instances are not thread-safe. Every setter returns the same
     * builder, so calls can be chained, and {@code null} is accepted to clear a member.
     */
    @Mutable
    @NotThreadSafe
    public interface Builder extends SdkPojo, CopyableBuilder<Builder, RFTHyperParameters> {
        /**
         * <p>
         * Number of training epochs to run during reinforcement fine-tuning. Higher values may improve performance but
         * increase training time.
         * </p>
         * 
         * @param epochCount
         *        Number of training epochs to run during reinforcement fine-tuning. Higher values may improve
         *        performance but increase training time.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder epochCount(Integer epochCount);

        /**
         * <p>
         * Number of training samples processed in each batch during reinforcement fine-tuning (RFT) training. Larger
         * batches may improve training stability.
         * </p>
         * 
         * @param batchSize
         *        Number of training samples processed in each batch during reinforcement fine-tuning (RFT) training.
         *        Larger batches may improve training stability.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder batchSize(Integer batchSize);

        /**
         * <p>
         * Learning rate for the reinforcement fine-tuning. Controls how quickly the model adapts to reward signals.
         * </p>
         * 
         * @param learningRate
         *        Learning rate for the reinforcement fine-tuning. Controls how quickly the model adapts to reward
         *        signals.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder learningRate(Float learningRate);

        /**
         * <p>
         * Maximum length of input prompts during RFT training, measured in tokens. Longer prompts allow more context
         * but increase memory usage and training time.
         * </p>
         * 
         * @param maxPromptLength
         *        Maximum length of input prompts during RFT training, measured in tokens. Longer prompts allow more
         *        context but increase memory usage and training time.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder maxPromptLength(Integer maxPromptLength);

        /**
         * <p>
         * Number of response samples generated per prompt during RFT training. More samples provide better reward
         * signal estimation.
         * </p>
         * 
         * @param trainingSamplePerPrompt
         *        Number of response samples generated per prompt during RFT training. More samples provide better
         *        reward signal estimation.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder trainingSamplePerPrompt(Integer trainingSamplePerPrompt);

        /**
         * <p>
         * Maximum number of tokens the model can generate in response to each prompt during RFT training.
         * </p>
         * 
         * @param inferenceMaxTokens
         *        Maximum number of tokens the model can generate in response to each prompt during RFT training.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder inferenceMaxTokens(Integer inferenceMaxTokens);

        /**
         * <p>
         * Level of reasoning effort applied during RFT training. Higher values may improve response quality but
         * increase training time.
         * </p>
         * <p>
         * The string is stored as-is, so values not known to this SDK version are accepted and preserved.
         * </p>
         * 
         * @param reasoningEffort
         *        Level of reasoning effort applied during RFT training. Higher values may improve response quality but
         *        increase training time.
         * @return Returns a reference to this object so that method calls can be chained together.
         * @see ReasoningEffort
         */
        Builder reasoningEffort(String reasoningEffort);

        /**
         * <p>
         * Level of reasoning effort applied during RFT training. Higher values may improve response quality but
         * increase training time.
         * </p>
         * <p>
         * Convenience overload that stores {@code reasoningEffort.toString()}; passing {@code null} clears the value.
         * </p>
         * 
         * @param reasoningEffort
         *        Level of reasoning effort applied during RFT training. Higher values may improve response quality but
         *        increase training time.
         * @return Returns a reference to this object so that method calls can be chained together.
         * @see ReasoningEffort
         */
        Builder reasoningEffort(ReasoningEffort reasoningEffort);

        /**
         * <p>
         * Interval between evaluation runs during RFT training, measured in training steps. More frequent evaluation
         * provides better monitoring.
         * </p>
         * 
         * @param evalInterval
         *        Interval between evaluation runs during RFT training, measured in training steps. More frequent
         *        evaluation provides better monitoring.
         * @return Returns a reference to this object so that method calls can be chained together.
         */
        Builder evalInterval(Integer evalInterval);
    }

    /**
     * Default {@link Builder} implementation. Besides the fluent setters, it exposes bean-style
     * {@code getX}/{@code setX} accessors for every member. NOTE(review): these are presumably used for reflective
     * (de)serialization via {@link RFTHyperParameters#serializableBuilderClass()} - confirm before removing.
     */
    static final class BuilderImpl implements Builder {
        private Integer epochCount;

        private Integer batchSize;

        private Float learningRate;

        private Integer maxPromptLength;

        private Integer trainingSamplePerPrompt;

        private Integer inferenceMaxTokens;

        // Raw string form, mirroring RFTHyperParameters.reasoningEffort.
        private String reasoningEffort;

        private Integer evalInterval;

        /** Creates an empty builder; every member starts out null. */
        private BuilderImpl() {
        }

        /** Creates a builder pre-populated with every member value of {@code model}; used by toBuilder(). */
        private BuilderImpl(RFTHyperParameters model) {
            epochCount(model.epochCount);
            batchSize(model.batchSize);
            learningRate(model.learningRate);
            maxPromptLength(model.maxPromptLength);
            trainingSamplePerPrompt(model.trainingSamplePerPrompt);
            inferenceMaxTokens(model.inferenceMaxTokens);
            // Resolves to the String overload, so the raw value (including unknown values) is copied unchanged.
            reasoningEffort(model.reasoningEffort);
            evalInterval(model.evalInterval);
        }

        public final Integer getEpochCount() {
            return epochCount;
        }

        public final void setEpochCount(Integer epochCount) {
            this.epochCount = epochCount;
        }

        @Override
        public final Builder epochCount(Integer epochCount) {
            this.epochCount = epochCount;
            return this;
        }

        public final Integer getBatchSize() {
            return batchSize;
        }

        public final void setBatchSize(Integer batchSize) {
            this.batchSize = batchSize;
        }

        @Override
        public final Builder batchSize(Integer batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        public final Float getLearningRate() {
            return learningRate;
        }

        public final void setLearningRate(Float learningRate) {
            this.learningRate = learningRate;
        }

        @Override
        public final Builder learningRate(Float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public final Integer getMaxPromptLength() {
            return maxPromptLength;
        }

        public final void setMaxPromptLength(Integer maxPromptLength) {
            this.maxPromptLength = maxPromptLength;
        }

        @Override
        public final Builder maxPromptLength(Integer maxPromptLength) {
            this.maxPromptLength = maxPromptLength;
            return this;
        }

        public final Integer getTrainingSamplePerPrompt() {
            return trainingSamplePerPrompt;
        }

        public final void setTrainingSamplePerPrompt(Integer trainingSamplePerPrompt) {
            this.trainingSamplePerPrompt = trainingSamplePerPrompt;
        }

        @Override
        public final Builder trainingSamplePerPrompt(Integer trainingSamplePerPrompt) {
            this.trainingSamplePerPrompt = trainingSamplePerPrompt;
            return this;
        }

        public final Integer getInferenceMaxTokens() {
            return inferenceMaxTokens;
        }

        public final void setInferenceMaxTokens(Integer inferenceMaxTokens) {
            this.inferenceMaxTokens = inferenceMaxTokens;
        }

        @Override
        public final Builder inferenceMaxTokens(Integer inferenceMaxTokens) {
            this.inferenceMaxTokens = inferenceMaxTokens;
            return this;
        }

        public final String getReasoningEffort() {
            return reasoningEffort;
        }

        public final void setReasoningEffort(String reasoningEffort) {
            this.reasoningEffort = reasoningEffort;
        }

        @Override
        public final Builder reasoningEffort(String reasoningEffort) {
            this.reasoningEffort = reasoningEffort;
            return this;
        }

        @Override
        public final Builder reasoningEffort(ReasoningEffort reasoningEffort) {
            // Stores the enum's toString() form; null clears the member.
            this.reasoningEffort(reasoningEffort == null ? null : reasoningEffort.toString());
            return this;
        }

        public final Integer getEvalInterval() {
            return evalInterval;
        }

        public final void setEvalInterval(Integer evalInterval) {
            this.evalInterval = evalInterval;
        }

        @Override
        public final Builder evalInterval(Integer evalInterval) {
            this.evalInterval = evalInterval;
            return this;
        }

        /** Creates a new immutable {@link RFTHyperParameters} from this builder's current values. */
        @Override
        public RFTHyperParameters build() {
            return new RFTHyperParameters(this);
        }

        @Override
        public List<SdkField<?>> sdkFields() {
            return SDK_FIELDS;
        }

        @Override
        public Map<String, SdkField<?>> sdkFieldNameToField() {
            return SDK_NAME_TO_FIELD;
        }
    }
}
