/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.evaluator;

import java.util.Map;
import org.drools.core.util.StringUtils;
import org.kie.api.KieBase;
import org.kie.api.pmml.PMML4Result;
import org.kie.pmml.commons.enums.ResultCode;
import org.kie.pmml.commons.exceptions.KiePMMLInternalException;
import org.kie.pmml.commons.model.KiePMMLModel;
import org.kie.pmml.commons.model.enums.PMML_MODEL;
import org.kie.pmml.evaluator.api.exceptions.KiePMMLModelException;
import org.kie.pmml.evaluator.api.executor.PMMLContext;
import org.kie.pmml.evaluator.core.executor.PMMLModelEvaluator;
import org.kie.pmml.evaluator.core.utils.Converter;
import org.kie.pmml.models.regression.model.KiePMMLRegressionClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionModel;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;

public class PMMLRegressionModelEvaluator
implements PMMLModelEvaluator {
    private static final String INVALID_NORMALIZATION_METHOD = "Invalid Normalization Method %s";
    private static final String EXPECTED_AT_LEAST_TWO_REGRESSION_TABLES_RETRIEVED = "Expected at least two RegressionTables, retrieved %s";
    private static final String EXPECTED_TWO_REGRESSION_TABLES_RETRIEVED = "Expected two RegressionTables, retrieved %s";
    private static final String EXPECTED_A_KIE_PMMLREGRESSION_MODEL_RECEIVED = "Expected a KiePMMLRegressionModel, received %s ";
    private static final String TARGET_FIELD_REQUIRED_RETRIEVED = "TargetField required, retrieved %s";
    private static final String INVALID_TARGET_TYPE = "Invalid target type %s";

    public PMML_MODEL getPMMLModelType() {
        return PMML_MODEL.REGRESSION_MODEL;
    }

    public PMML4Result evaluate(KieBase knowledgeBase, KiePMMLModel model, PMMLContext pmmlContext) {
        this.validate(model);
        PMML4Result toReturn = new PMML4Result();
        String targetField = model.getTargetField();
        Map requestData = Converter.getUnwrappedParametersMap((Map)pmmlContext.getRequestData().getMappedRequestParams());
        Object result = model.evaluate((Object)knowledgeBase, requestData);
        toReturn.addResultVariable(targetField, result);
        toReturn.setResultObjectName(targetField);
        toReturn.setResultCode(ResultCode.OK.getName());
        model.getOutputFieldsMap().forEach((arg_0, arg_1) -> ((PMML4Result)toReturn).addResultVariable(arg_0, arg_1));
        return toReturn;
    }

    private void validate(KiePMMLModel toValidate) {
        if (!(toValidate instanceof KiePMMLRegressionModel)) {
            throw new KiePMMLModelException(String.format(EXPECTED_A_KIE_PMMLREGRESSION_MODEL_RECEIVED, toValidate.getClass().getName()));
        }
        if (((KiePMMLRegressionModel)toValidate).getRegressionTable() == null) {
            throw new KiePMMLModelException("At least one RegressionTable required");
        }
        KiePMMLRegressionTable regressionTable = ((KiePMMLRegressionModel)toValidate).getRegressionTable();
        if (regressionTable instanceof KiePMMLRegressionClassificationTable) {
            this.validateClassification((KiePMMLRegressionClassificationTable)regressionTable);
        } else {
            this.validateRegression(regressionTable);
        }
    }

    private void validateRegression(KiePMMLRegressionTable toValidate) {
        if (toValidate.getTargetField() == null || StringUtils.isEmpty((CharSequence)toValidate.getTargetField().trim())) {
            throw new KiePMMLInternalException(String.format(TARGET_FIELD_REQUIRED_RETRIEVED, toValidate.getTargetField()));
        }
    }

    private void validateClassification(KiePMMLRegressionClassificationTable toValidate) {
        switch (toValidate.getOpType()) {
            case CATEGORICAL: {
                this.validateClassificationCategorical(toValidate);
                break;
            }
            case ORDINAL: {
                this.validateClassificationOrdinal(toValidate);
                break;
            }
            default: {
                throw new KiePMMLModelException(String.format(INVALID_TARGET_TYPE, toValidate.getOpType()));
            }
        }
    }

    private void validateClassificationCategorical(KiePMMLRegressionClassificationTable toValidate) {
        if (toValidate.isBinary()) {
            this.validateClassificationCategoricalBinary(toValidate);
        } else {
            this.validateClassificationCategoricalNotBinary(toValidate);
        }
    }

    private void validateClassificationCategoricalBinary(KiePMMLRegressionClassificationTable toValidate) {
        switch (toValidate.getRegressionNormalizationMethod()) {
            case LOGIT: 
            case PROBIT: 
            case CAUCHIT: 
            case CLOGLOG: 
            case LOGLOG: 
            case NONE: {
                if (toValidate.getCategoryTableMap().size() != 2) {
                    throw new KiePMMLModelException(String.format(EXPECTED_TWO_REGRESSION_TABLES_RETRIEVED, toValidate.getCategoryTableMap().size()));
                }
                return;
            }
        }
        throw new KiePMMLModelException(String.format(INVALID_NORMALIZATION_METHOD, toValidate.getRegressionNormalizationMethod()));
    }

    private void validateClassificationCategoricalNotBinary(KiePMMLRegressionClassificationTable toValidate) {
        switch (toValidate.getRegressionNormalizationMethod()) {
            case NONE: 
            case SOFTMAX: 
            case SIMPLEMAX: {
                if (toValidate.getCategoryTableMap().size() < 2) {
                    throw new KiePMMLModelException(String.format(EXPECTED_AT_LEAST_TWO_REGRESSION_TABLES_RETRIEVED, toValidate.getCategoryTableMap().size()));
                }
                return;
            }
        }
        throw new KiePMMLModelException(String.format(INVALID_NORMALIZATION_METHOD, toValidate.getRegressionNormalizationMethod()));
    }

    private void validateClassificationOrdinal(KiePMMLRegressionClassificationTable toValidate) {
        switch (toValidate.getRegressionNormalizationMethod()) {
            case LOGIT: 
            case PROBIT: 
            case CAUCHIT: 
            case CLOGLOG: 
            case LOGLOG: 
            case NONE: {
                if (toValidate.getCategoryTableMap().size() < 2) {
                    throw new KiePMMLModelException(String.format(EXPECTED_AT_LEAST_TWO_REGRESSION_TABLES_RETRIEVED, toValidate.getCategoryTableMap().size()));
                }
                return;
            }
        }
        throw new KiePMMLModelException(String.format(INVALID_NORMALIZATION_METHOD, toValidate.getRegressionNormalizationMethod()));
    }
}

