/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.executor;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.enums.PMML_MODEL;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.tuples.KiePMMLNameOpType;
import org.kie.pmml.compiler.api.dto.CompilationDTO;
import org.kie.pmml.compiler.api.provider.ModelImplementationProvider;
import org.kie.pmml.compiler.api.utils.ModelUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionModelFactory;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ModelImplementationProvider} for PMML {@link RegressionModel}s, producing
 * {@link KiePMMLRegressionModel} instances.
 * <p>
 * Besides source generation, this class offers {@link #validate(List, RegressionModel)}, which checks the
 * number of {@code RegressionTable}s, the allowed {@code normalizationMethod} and the target field(s)
 * depending on whether the model's mining function is regression or classification.
 * Note that {@link #getSourcesMap(CompilationDTO)} does not invoke {@code validate} itself.
 */
public class RegressionModelImplementationProvider implements ModelImplementationProvider<RegressionModel, KiePMMLRegressionModel> {

    private static final Logger logger = LoggerFactory.getLogger(RegressionModelImplementationProvider.class);
    private static final String INVALID_NORMALIZATION_METHOD = "Invalid Normalization Method ";

    /**
     * Returns the PMML model type handled by this provider.
     *
     * @return always {@link PMML_MODEL#REGRESSION_MODEL}
     */
    public PMML_MODEL getPMMLModelType() {
        logger.trace("getPMMLModelType");
        return PMML_MODEL.REGRESSION_MODEL;
    }

    /**
     * Returns the class of the model objects this provider is associated with.
     *
     * @return {@link KiePMMLRegressionModel}{@code .class}
     */
    public Class<KiePMMLRegressionModel> getKiePMMLModelClass() {
        return KiePMMLRegressionModel.class;
    }

    /**
     * Generates the sources for the regression model wrapped by {@code compilationDTO}.
     *
     * @param compilationDTO compilation context holding package name, fields, model and PMML context
     * @return the map produced by {@link KiePMMLRegressionModelFactory#getKiePMMLRegressionModelSourcesMap}
     *         (presumably class name to source code — TODO confirm against the factory)
     * @throws KiePMMLException wrapping any {@link IOException} raised while generating sources
     */
    public Map<String, String> getSourcesMap(CompilationDTO<RegressionModel> compilationDTO) {
        logger.trace("getKiePMMLModelWithSources {} {} {} {}",
                     compilationDTO.getPackageName(),
                     compilationDTO.getFields(),
                     compilationDTO.getModel(),
                     compilationDTO.getPmmlContext());
        try {
            return KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap(
                    RegressionCompilationDTO.fromCompilationDTO(compilationDTO));
        } catch (IOException e) {
            throw new KiePMMLException(e);
        }
    }

    /**
     * Validates the given model: it must have at least one {@code RegressionTable}; then the
     * regression or classification specific rules are applied depending on its mining function.
     *
     * @param fields     all the fields available to the model
     * @param toValidate the model to check
     * @throws KiePMMLException if any constraint is violated
     */
    protected void validate(List<Field<?>> fields, RegressionModel toValidate) {
        if (toValidate.getRegressionTables() == null || toValidate.getRegressionTables().isEmpty()) {
            throw new KiePMMLException("At least one RegressionTable required");
        }
        if (isRegression(toValidate)) {
            List<KiePMMLNameOpType> targetFields = ModelUtils.getTargetFields(fields, toValidate);
            validateRegression(targetFields, toValidate);
        } else {
            validateClassification(fields, toValidate);
        }
    }

    /**
     * Regression rules: exactly one target field (matching the model's {@code targetField}, if set),
     * exactly one {@code RegressionTable}, and a normalization method allowed for regression.
     *
     * @throws KiePMMLException if any constraint is violated
     */
    void validateRegression(List<KiePMMLNameOpType> targetFields, RegressionModel toValidate) {
        validateRegressionTargetField(targetFields, toValidate);
        int tables = toValidate.getRegressionTables().size();
        if (tables != 1) {
            throw new KiePMMLException("Expected one RegressionTable, retrieved " + tables);
        }
        validateNormalizationMethod(toValidate.getNormalizationMethod());
    }

    /**
     * Accepts the normalization methods allowed for regression; every other value (e.g. SIMPLEMAX)
     * is rejected.
     *
     * @throws KiePMMLException if the method is not allowed
     */
    void validateNormalizationMethod(RegressionModel.NormalizationMethod toValidate) {
        switch (toValidate) {
            case NONE:
            case SOFTMAX:
            case LOGIT:
            case EXP:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + toValidate);
        }
    }

    /**
     * Classification rules: dispatches on the op type of the single categorical target.
     *
     * @throws KiePMMLException if the target type is neither CATEGORICAL nor ORDINAL, or if the
     *                          type-specific checks fail
     */
    private void validateClassification(List<Field<?>> fields, RegressionModel toValidate) {
        String categoricalTargetName = getCategoricalTargetName(fields, toValidate);
        OP_TYPE opType = ModelUtils.getOpType(fields, toValidate, categoricalTargetName);
        switch (opType) {
            case CATEGORICAL:
                validateClassificationCategorical(fields, toValidate, categoricalTargetName);
                break;
            case ORDINAL:
                validateClassificationOrdinal(toValidate);
                break;
            default:
                throw new KiePMMLException("Invalid target type " + opType);
        }
    }

    /**
     * Chooses binary or multi-class rules depending on the number of values declared by the
     * target's {@link DataField}.
     */
    private void validateClassificationCategorical(List<Field<?>> fields, RegressionModel toValidate, String categoricalFieldName) {
        if (isBinary(fields, categoricalFieldName)) {
            validateClassificationCategoricalBinary(toValidate);
        } else {
            validateClassificationCategoricalNotBinary(toValidate);
        }
    }

    /**
     * Binary categorical target: requires exactly two {@code RegressionTable}s and one of
     * NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG, CAUCHIT.
     *
     * @throws KiePMMLException if any constraint is violated
     */
    private void validateClassificationCategoricalBinary(RegressionModel toValidate) {
        int tables = toValidate.getRegressionTables().size();
        switch (toValidate.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (tables != 2) {
                    throw new KiePMMLException("Expected two RegressionTables, retrieved " + tables);
                }
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + toValidate.getNormalizationMethod());
        }
    }

    /**
     * Multi-class categorical target: SOFTMAX/SIMPLEMAX require at least two
     * {@code RegressionTable}s, NONE requires at least three; other methods are rejected.
     *
     * @throws KiePMMLException if any constraint is violated
     */
    private void validateClassificationCategoricalNotBinary(RegressionModel toValidate) {
        int tables = toValidate.getRegressionTables().size();
        switch (toValidate.getNormalizationMethod()) {
            case SOFTMAX:
            case SIMPLEMAX:
                if (tables < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + tables);
                }
                return;
            case NONE:
                // The check is a lower bound, so the message must say "at least".
                if (tables < 3) {
                    throw new KiePMMLException("Expected at least three RegressionTables, retrieved " + tables);
                }
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + toValidate.getNormalizationMethod());
        }
    }

    /**
     * Ordinal target: requires at least two {@code RegressionTable}s and one of
     * NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG, CAUCHIT.
     *
     * @throws KiePMMLException if any constraint is violated
     */
    private void validateClassificationOrdinal(RegressionModel toValidate) {
        int tables = toValidate.getRegressionTables().size();
        switch (toValidate.getNormalizationMethod()) {
            case NONE:
            case LOGIT:
            case PROBIT:
            case CLOGLOG:
            case LOGLOG:
            case CAUCHIT:
                if (tables < 2) {
                    throw new KiePMMLException("Expected at least two RegressionTables, retrieved " + tables);
                }
                return;
            default:
                throw new KiePMMLException(INVALID_NORMALIZATION_METHOD + toValidate.getNormalizationMethod());
        }
    }

    /**
     * Requires exactly one target field; when the model declares a {@code targetField}, its value
     * must equal that target's name.
     *
     * @throws KiePMMLException if the count or the name does not match
     */
    private void validateRegressionTargetField(List<KiePMMLNameOpType> targetFields, RegressionModel toValidate) {
        if (targetFields.size() != 1) {
            throw new KiePMMLException("Expected one target field, retrieved " + targetFields.size());
        }
        String targetFieldName = targetFields.get(0).getName();
        if (toValidate.getTargetField() != null && !Objects.equals(toValidate.getTargetField().getValue(), targetFieldName)) {
            throw new KiePMMLException(String.format("Not-matching target fields: %s %s", toValidate.getTargetField(), targetFieldName));
        }
    }

    /** @return {@code true} when the model's mining function is {@link MiningFunction#REGRESSION}. */
    private boolean isRegression(RegressionModel toValidate) {
        return Objects.equals(MiningFunction.REGRESSION, toValidate.getMiningFunction());
    }

    /**
     * @return {@code true} when the first {@link DataField} named {@code categoricalFieldName}
     *         declares exactly two values; {@code false} otherwise (including when no such field exists)
     */
    private boolean isBinary(List<Field<?>> fields, String categoricalFieldName) {
        return fields.stream()
                .filter(DataField.class::isInstance)
                .map(DataField.class::cast)
                .filter(dataField -> Objects.equals(dataField.getName().getValue(), categoricalFieldName))
                .mapToInt(dataField -> dataField.getValues().size())
                .findFirst()
                .orElse(0) == 2;
    }

    /**
     * Returns the name of the single target field whose field definition has op type CATEGORICAL.
     * <p>
     * NOTE(review): targets are matched against fields whose own op type is CATEGORICAL, so a target
     * declared ORDINAL on its field is rejected here; the ORDINAL branch in the classification
     * dispatch is presumably reachable only when the model-level op type differs — confirm intent.
     *
     * @throws KiePMMLException if there is not exactly one such target
     */
    private String getCategoricalTargetName(List<Field<?>> fields, RegressionModel toValidate) {
        List<KiePMMLNameOpType> targetFields = ModelUtils.getTargetFields(fields, toValidate);
        List<String> categoricalFieldNames = fields.stream()
                .filter(field -> OpType.CATEGORICAL.equals(field.getOpType()))
                .map(field -> field.getName().getValue())
                .collect(Collectors.toList());
        List<KiePMMLNameOpType> categoricalTargets = targetFields.stream()
                .filter(targetField -> categoricalFieldNames.contains(targetField.getName()))
                .collect(Collectors.toList());
        if (categoricalTargets.size() != 1) {
            throw new KiePMMLException(String.format("Expected exactly one categorical target, found %s", categoricalTargets.size()));
        }
        return categoricalTargets.get(0).getName();
    }
}

