package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.function.Supplier;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

/* loaded from: input_file:lib/x-pack-core-7.17.18.jar:org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.class */
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
    /** Shallow instance size of this class; assigned in the static initializer and used as the base of {@link #ramBytesUsed()}. */
    private static final long SHALLOW_SIZE;
    /** Name of this model type; its preferred name is returned by {@link #getWriteableName()} and {@link #getName()}. */
    public static final ParseField NAME;
    /** Field holding the list of feature names. */
    public static final ParseField FEATURE_NAMES;
    /** Field holding the member models of the ensemble. */
    public static final ParseField TRAINED_MODELS;
    /** Field holding the aggregator that combines the member models' outputs. */
    public static final ParseField AGGREGATE_OUTPUT;
    /** Field holding the target type. */
    public static final ParseField TARGET_TYPE;
    /** Field holding the optional classification labels. */
    public static final ParseField CLASSIFICATION_LABELS;
    /** Field holding the optional classification weights. */
    public static final ParseField CLASSIFICATION_WEIGHTS;
    /** Parser created with {@code createParser(true)}; used by {@link #fromXContentLenient}. */
    private static final ObjectParser<Builder, Void> LENIENT_PARSER;
    /** Parser created with {@code createParser(false)}; used by {@link #fromXContentStrict}. */
    private static final ObjectParser<Builder, Void> STRICT_PARSER;
    // Instance state. The two nullable members (classificationLabels, classificationWeights)
    // are only allowed for classification targets; see validate().
    private final List<String> featureNames;
    private final List<TrainedModel> models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    private final List<String> classificationLabels;
    private final double[] classificationWeights;
    // Decompiler-surfaced assertion flag: set in the static initializer to the negation of
    // Ensemble.class.desiredAssertionStatus() and consulted by estimatedNumOperations().
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:lib/x-pack-core-7.17.18.jar:org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble$Builder.class */
    /**
     * Mutable builder for {@link Ensemble}.
     *
     * <p>Defaults: an empty feature-name list, a {@link WeightedSum} aggregator and the
     * {@link TargetType#REGRESSION} target type. Builders obtained through the public constructor
     * treat the supplied models as ordered; builders created for the parsers start out unordered
     * and are switched to ordered by the callback the parser registers for the trained-models field.
     */
    public static class Builder {
        private List<String> featureNames = Collections.emptyList();
        private List<TrainedModel> trainedModels;
        private OutputAggregator outputAggregator = new WeightedSum();
        private TargetType targetType = TargetType.REGRESSION;
        private List<String> classificationLabels;
        private double[] classificationWeights;
        private boolean modelsAreOrdered;

        /**
         * @param z initial value of the "models are ordered" flag
         */
        private Builder(boolean z) {
            this.modelsAreOrdered = z;
        }

        /** Creates the builder variant used by the parsers: model order is not yet established. */
        private static Builder builderForParser() {
            return new Builder(false);
        }

        /** Creates a builder for programmatic use; the model list is considered ordered. */
        public Builder() {
            this(true);
        }

        /** Sets the feature names; must be non-null by the time {@link #build()} is called. */
        public Builder setFeatureNames(List<String> list) {
            this.featureNames = list;
            return this;
        }

        /** Sets the member models; must be non-null by the time {@link #build()} is called. */
        public Builder setTrainedModels(List<TrainedModel> list) {
            this.trainedModels = list;
            return this;
        }

        /**
         * Sets the output aggregator.
         *
         * @throws RuntimeException (as raised by {@code ExceptionsHelper.requireNonNull}) when the aggregator is null
         */
        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
            this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, Ensemble.AGGREGATE_OUTPUT);
            return this;
        }

        /** Sets the target type. */
        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        /** Sets the optional classification labels; {@code null} clears them. */
        public Builder setClassificationLabels(List<String> list) {
            this.classificationLabels = list;
            return this;
        }

        /**
         * Sets the optional classification weights.
         *
         * @param list the weights, or {@code null} to clear them (previously a null list caused a
         *             {@link NullPointerException} even though the weights are optional)
         */
        public Builder setClassificationWeights(List<Double> list) {
            this.classificationWeights = list == null
                ? null
                : list.stream().mapToDouble(Double::doubleValue).toArray();
            return this;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /** Parser hook: sets the target type from its string form via {@code TargetType.fromString}. */
        public void setTargetType(String str) {
            this.targetType = TargetType.fromString(str);
        }

        /* JADX INFO: Access modifiers changed from: private */
        /** Parser hook: records whether the order of the trained models is known. */
        public void setModelsAreOrdered(boolean z) {
            this.modelsAreOrdered = z;
        }

        /**
         * Creates the {@link Ensemble}.
         *
         * @throws RuntimeException a bad-request exception when more than one model was supplied
         *                          without the models having been marked as ordered
         */
        public Ensemble build() {
            // A single (or absent) model has no ordering problem; several unordered ones do.
            boolean severalUnorderedModels = !this.modelsAreOrdered
                && this.trainedModels != null
                && this.trainedModels.size() > 1;
            if (severalUnorderedModels) {
                throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects", new Object[0]);
            }
            return new Ensemble(
                this.featureNames,
                this.trainedModels,
                this.outputAggregator,
                this.targetType,
                this.classificationLabels,
                this.classificationWeights
            );
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /** Accessor kept for the enclosing class's parser factory, which obtains builders through it. */
        public static /* synthetic */ Builder access$200() {
            return builderForParser();
        }
    }

    /**
     * Creates the object parser that fills a {@link Builder} from x-content.
     *
     * @param z {@code true} for the lenient variant, {@code false} for the strict one. The flag is
     *          forwarded to the {@link ObjectParser} constructor and also selects whether nested
     *          models and aggregators are resolved through the leniently or the strictly parsed
     *          named-object categories.
     * @return a parser with every {@code Ensemble} field declared
     */
    private static ObjectParser<Builder, Void> createParser(boolean z) {
        final boolean lenient = z;
        ObjectParser<Builder, Void> parser = new ObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            (Supplier<Builder>) Builder::access$200
        );

        parser.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);

        parser.declareNamedObjects(Builder::setTrainedModels, (p, ctx, name) -> {
            if (lenient) {
                return (TrainedModel) p.namedObject(LenientlyParsedTrainedModel.class, name, null);
            }
            return (TrainedModel) p.namedObject(StrictlyParsedTrainedModel.class, name, null);
        }, b -> b.setModelsAreOrdered(true), TRAINED_MODELS);

        parser.declareNamedObject(Builder::setOutputAggregator, (p, ctx, name) -> {
            if (lenient) {
                return (OutputAggregator) p.namedObject(LenientlyParsedOutputAggregator.class, name, null);
            }
            return (OutputAggregator) p.namedObject(StrictlyParsedOutputAggregator.class, name, null);
        }, AGGREGATE_OUTPUT);

        // Explicit lambda: setTargetType is overloaded and the String variant is the parser hook.
        parser.declareString((b, value) -> b.setTargetType(value), TARGET_TYPE);
        parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
        parser.declareDoubleArray(Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
        return parser;
    }

    /**
     * Parses an {@link Ensemble} with the strict parser and builds it.
     *
     * @param xContentParser parser positioned at the ensemble object
     * @return the built ensemble
     */
    public static Ensemble fromXContentStrict(XContentParser xContentParser) {
        // The parser's context type is Void, so the only valid context is null.
        return STRICT_PARSER.apply(xContentParser, null).build();
    }

    /**
     * Parses an {@link Ensemble} with the lenient parser and builds it.
     *
     * @param xContentParser parser positioned at the ensemble object
     * @return the built ensemble
     */
    public static Ensemble fromXContentLenient(XContentParser xContentParser) {
        // The parser's context type is Void, so the only valid context is null.
        return LENIENT_PARSER.apply(xContentParser, null).build();
    }

    /**
     * Creates an ensemble from already-built parts.
     *
     * <p>The four mandatory arguments are null-checked in declaration order. Both lists are
     * wrapped in unmodifiable views (not copied); the weight array is copied.
     *
     * @param list             feature names, required
     * @param list2            member models, required
     * @param outputAggregator aggregator for the member outputs, required
     * @param targetType       target type, required
     * @param list3            classification labels, may be {@code null}
     * @param dArr             classification weights, may be {@code null}
     */
    Ensemble(List<String> list, List<TrainedModel> list2, OutputAggregator outputAggregator, TargetType targetType, @Nullable List<String> list3, @Nullable double[] dArr) {
        this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(list, FEATURE_NAMES));
        this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(list2, TRAINED_MODELS));
        this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
        if (list3 != null) {
            this.classificationLabels = Collections.unmodifiableList(list3);
        } else {
            this.classificationLabels = null;
        }
        if (dArr != null) {
            this.classificationWeights = Arrays.copyOf(dArr, dArr.length);
        } else {
            this.classificationWeights = null;
        }
    }

    /**
     * Reads an ensemble from the wire, in the same field order that {@link #writeTo} emits:
     * feature names, models, aggregator, target type, then the two optional members, each
     * preceded by a presence flag.
     *
     * @param streamInput stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Ensemble(StreamInput streamInput) throws IOException {
        this.featureNames = Collections.unmodifiableList(streamInput.readStringList());
        this.models = Collections.unmodifiableList(streamInput.readNamedWriteableList(TrainedModel.class));
        this.outputAggregator = (OutputAggregator) streamInput.readNamedWriteable(OutputAggregator.class);
        this.targetType = TargetType.fromStream(streamInput);
        // Each optional member is preceded by a boolean telling whether it follows.
        this.classificationLabels = streamInput.readBoolean() ? streamInput.readStringList() : null;
        this.classificationWeights = streamInput.readBoolean() ? streamInput.readDoubleArray() : null;
    }

    /**
     * @return the target type this ensemble was configured with
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public TargetType targetType() {
        return targetType;
    }

    /**
     * @return the preferred name of {@link #NAME}
     */
    @Override // org.elasticsearch.common.io.stream.NamedWriteable
    public String getWriteableName() {
        final String writeableName = NAME.getPreferredName();
        return writeableName;
    }

    /**
     * Serialises this ensemble: feature names, models, aggregator and target type, followed by
     * the optional labels and weights, each preceded by a presence flag.
     *
     * @param streamOutput stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override // org.elasticsearch.common.io.stream.Writeable
    public void writeTo(StreamOutput streamOutput) throws IOException {
        streamOutput.writeStringCollection(featureNames);
        streamOutput.writeNamedWriteableList(models);
        streamOutput.writeNamedWriteable(outputAggregator);
        targetType.writeTo(streamOutput);

        final boolean hasLabels = classificationLabels != null;
        streamOutput.writeBoolean(hasLabels);
        if (hasLabels) {
            streamOutput.writeStringCollection(classificationLabels);
        }

        final boolean hasWeights = classificationWeights != null;
        streamOutput.writeBoolean(hasWeights);
        if (hasWeights) {
            streamOutput.writeDoubleArray(classificationWeights);
        }
    }

    /**
     * @return the preferred name of {@link #NAME}
     */
    @Override // org.elasticsearch.xpack.core.ml.utils.NamedXContentObject
    public String getName() {
        final String name = NAME.getPreferredName();
        return name;
    }

    /**
     * Writes this ensemble as an x-content object. Feature names are emitted only when the list
     * is non-empty; labels and weights only when they are set.
     *
     * @param xContentBuilder builder to write into
     * @param params          serialisation parameters, forwarded to the nested named objects
     * @return the same builder, for chaining
     * @throws IOException if writing fails
     */
    @Override // org.elasticsearch.xcontent.ToXContent
    public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject();
        if (featureNames.isEmpty() == false) {
            xContentBuilder.field(FEATURE_NAMES.getPreferredName(), (Iterable<?>) featureNames);
        }
        NamedXContentObjectHelper.writeNamedObjects(xContentBuilder, params, true, TRAINED_MODELS.getPreferredName(), models);
        NamedXContentObjectHelper.writeNamedObjects(
            xContentBuilder,
            params,
            false,
            AGGREGATE_OUTPUT.getPreferredName(),
            Collections.singletonList(outputAggregator)
        );
        xContentBuilder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
        if (classificationLabels != null) {
            xContentBuilder.field(CLASSIFICATION_LABELS.getPreferredName(), (Iterable<?>) classificationLabels);
        }
        if (classificationWeights != null) {
            xContentBuilder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
        }
        return xContentBuilder.endObject();
    }

    /**
     * Two ensembles are equal when they are of the same class and all six members match; the
     * weight arrays are compared element-wise.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (obj.getClass() != getClass()) {
            return false;
        }
        final Ensemble other = (Ensemble) obj;
        if (!Objects.equals(featureNames, other.featureNames)) {
            return false;
        }
        if (!Objects.equals(models, other.models)) {
            return false;
        }
        if (!Objects.equals(targetType, other.targetType)) {
            return false;
        }
        if (!Objects.equals(classificationLabels, other.classificationLabels)) {
            return false;
        }
        if (!Objects.equals(outputAggregator, other.outputAggregator)) {
            return false;
        }
        return Arrays.equals(classificationWeights, other.classificationWeights);
    }

    /**
     * Hash over the same six members that {@link #equals(Object)} compares; the weight array
     * contributes its content hash rather than its identity hash.
     */
    @Override
    public int hashCode() {
        final int weightsHash = Arrays.hashCode(classificationWeights);
        return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels, weightsHash);
    }

    /**
     * Checks the internal consistency of this ensemble and then validates every member model.
     *
     * <p>Checks, in order: at least one model is present; the aggregator is compatible with the
     * target type; the aggregator's expected value count (when it declares one) equals the number
     * of models; labels/weights are only present for classification targets; labels and weights
     * have the same length when both are given.
     *
     * @throws RuntimeException a bad-request exception describing the first violated rule
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public void validate() {
        if (this.models.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must not be empty", TRAINED_MODELS.getPreferredName());
        }
        if (!this.outputAggregator.compatibleWith(this.targetType)) {
            // Argument order follows the placeholders: aggregator name first, target type second.
            throw ExceptionsHelper.badRequestException(
                "aggregate_output [{}] is not compatible with target_type [{}]",
                this.outputAggregator.getName(),
                this.targetType
            );
        }
        if (this.outputAggregator.expectedValueSize() != null && this.outputAggregator.expectedValueSize().intValue() != this.models.size()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] expects value array of size [{}] but number of models is [{}]",
                AGGREGATE_OUTPUT.getPreferredName(),
                this.outputAggregator.expectedValueSize(),
                Integer.valueOf(this.models.size())
            );
        }
        boolean hasClassificationSettings = this.classificationLabels != null || this.classificationWeights != null;
        if (hasClassificationSettings && this.targetType != TargetType.CLASSIFICATION) {
            throw ExceptionsHelper.badRequestException("[target_type] should be [classification] if [classification_labels] or [classification_weights] are provided", new Object[0]);
        }
        if (this.classificationWeights != null && this.classificationLabels != null && this.classificationWeights.length != this.classificationLabels.size()) {
            throw ExceptionsHelper.badRequestException("[classification_weights] and [classification_labels] should be the same length if both are provided", new Object[0]);
        }
        this.models.forEach(TrainedModel::validate);
    }

    /**
     * Estimates the operation count as the mean of the member models' estimates, rounded up,
     * plus {@code 2 * (numberOfModels - 1)}.
     *
     * @return the estimated number of operations
     * @throws AssertionError when assertions are enabled and there are no models to average
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public long estimatedNumOperations() {
        final OptionalDouble meanOps = this.models.stream().mapToLong(TrainedModel::estimatedNumOperations).average();
        // Mirrors an assert statement: only enforced when assertions are enabled for this class.
        if (!$assertionsDisabled && !meanOps.isPresent()) {
            throw new AssertionError("unexpected null when calculating number of operations");
        }
        final long roundedMean = (long) Math.ceil(meanOps.getAsDouble());
        final int aggregationOverhead = 2 * (this.models.size() - 1);
        return roundedMean + aggregationOverhead;
    }

    /**
     * @return a fresh {@link Builder} created through its public constructor
     */
    public static Builder builder() {
        final Builder fresh = new Builder();
        return fresh;
    }

    /**
     * Sums the shallow size of this instance, the estimated sizes of the three collections, the
     * weight array when present, and the aggregator's own estimate.
     *
     * @return estimated heap usage in bytes
     */
    @Override // org.apache.lucene.util.Accountable
    public long ramBytesUsed() {
        long total = SHALLOW_SIZE;
        total += RamUsageEstimator.sizeOfCollection(this.featureNames);
        total += RamUsageEstimator.sizeOfCollection(this.classificationLabels);
        total += RamUsageEstimator.sizeOfCollection(this.models);
        if (this.classificationWeights != null) {
            total += RamUsageEstimator.sizeOf(this.classificationWeights);
        }
        total += this.outputAggregator.ramBytesUsed();
        return total;
    }

    /**
     * Lists the accountable children: every member model, in order, followed by the aggregator,
     * each wrapped under its own name.
     *
     * @return an unmodifiable collection of the named children
     */
    @Override // org.apache.lucene.util.Accountable
    public Collection<Accountable> getChildResources() {
        final List<Accountable> children = new ArrayList<>(this.models.size() + 1);
        this.models.forEach(model -> children.add(Accountables.namedAccountable(model.getName(), model)));
        children.add(Accountables.namedAccountable(this.outputAggregator.getName(), this.outputAggregator));
        return Collections.unmodifiableCollection(children);
    }

    /**
     * @return the highest minimal-compatibility version among the member models, or
     *         {@code Version.V_7_6_0} when there are no models
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public Version getMinimalCompatibilityVersion() {
        return this.models.stream()
            .map(TrainedModel::getMinimalCompatibilityVersion)
            .max(Version::compareTo)
            .orElse(Version.V_7_6_0);
    }

    // Static state is assigned here rather than at the declarations. Order matters: the parse
    // fields must exist before createParser(...) reads them.
    static {
        $assertionsDisabled = !Ensemble.class.desiredAssertionStatus();
        SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class);
        // Literal name instead of a reference to the REST client's Ensemble.NAME constant, so
        // this class does not depend on the client class.
        NAME = new ParseField("ensemble");
        FEATURE_NAMES = new ParseField("feature_names");
        TRAINED_MODELS = new ParseField("trained_models");
        AGGREGATE_OUTPUT = new ParseField("aggregate_output");
        TARGET_TYPE = new ParseField("target_type");
        CLASSIFICATION_LABELS = new ParseField("classification_labels");
        CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
        LENIENT_PARSER = createParser(true);
        STRICT_PARSER = createParser(false);
    }
}
