/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.compiler.factories.KiePMMLRegressionTableFactory;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factory for the classification table of a PMML regression model.
 * <p>
 * Two paths are offered:
 * <ul>
 *     <li>{@link #getClassificationTable(RegressionCompilationDTO)} builds a {@link KiePMMLClassificationTable}
 *     instance directly;</li>
 *     <li>{@link #getClassificationTableBuilders(RegressionCompilationDTO)} generates Java source code for the
 *     regression tables plus a classification table class derived from the
 *     {@value #KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA} template.</li>
 * </ul>
 */
public class KiePMMLClassificationTableFactory {

    private static final Logger logger = LoggerFactory.getLogger(KiePMMLClassificationTableFactory.class.getName());

    static final String KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA = "KiePMMLClassificationTableTemplate.tmpl";
    static final String KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE = "KiePMMLClassificationTableTemplate";
    static final String GETKIEPMML_TABLE = "getKiePMMLTable";
    static final String CATEGORICAL_TABLE_MAP = "categoryTableMap";

    /** Name of the local variable, declared inside the template getter, whose initializer is rewritten. */
    private static final String TO_RETURN = "toReturn";

    /** The template class, parsed once when this factory is initialized. */
    static final ClassOrInterfaceDeclaration CLASSIFICATION_TABLE_TEMPLATE;

    /**
     * Normalization methods for which a probability-map function is available.
     * Unmodifiable because it is shared, public, static state.
     */
    public static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS =
            Collections.unmodifiableList(Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX,
                                                       RegressionModel.NormalizationMethod.SIMPLEMAX,
                                                       RegressionModel.NormalizationMethod.NONE,
                                                       RegressionModel.NormalizationMethod.LOGIT,
                                                       RegressionModel.NormalizationMethod.PROBIT,
                                                       RegressionModel.NormalizationMethod.CLOGLOG,
                                                       RegressionModel.NormalizationMethod.CAUCHIT));

    /**
     * Normalization methods explicitly rejected for classification tables.
     * Unmodifiable because it is shared, public, static state.
     */
    public static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS =
            Collections.unmodifiableList(Arrays.asList(RegressionModel.NormalizationMethod.EXP,
                                                       RegressionModel.NormalizationMethod.LOGLOG));

    /** Counter appended to generated class names so that each generated classification table is distinct. */
    private static final AtomicInteger classArity = new AtomicInteger(0);

    static {
        CompilationUnit templateCU = JavaParserUtils.getFromFileName(KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA);
        CLASSIFICATION_TABLE_TEMPLATE = templateCU.getClassByName(KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE)
                .orElseThrow(() -> new KiePMMLException("Main class not found: " + KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE));
        // Fail fast at class initialization if the template lacks the getter that setStaticGetter rewrites.
        if (CLASSIFICATION_TABLE_TEMPLATE.getMethodsByName(GETKIEPMML_TABLE).isEmpty()) {
            throw new KiePMMLException(String.format("Method %s not found in %s", GETKIEPMML_TABLE,
                                                     KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE));
        }
    }

    private KiePMMLClassificationTableFactory() {
        // Static utility class: not instantiable.
    }

    /**
     * Builds an in-memory classification table for the given model.
     *
     * @param compilationDTO compilation data of the regression model
     * @return a table wrapping the category-to-regression-table map produced by
     * {@link KiePMMLRegressionTableFactory#getRegressionTables(RegressionCompilationDTO)}, with a random UUID as name
     * @throws KiePMMLInternalException if the model normalization method is unsupported
     */
    public static KiePMMLClassificationTable getClassificationTable(RegressionCompilationDTO compilationDTO) {
        logger.trace("getClassificationTable {}", compilationDTO);
        LinkedHashMap<String, KiePMMLRegressionTable> categoryTableMap =
                KiePMMLRegressionTableFactory.getRegressionTables(compilationDTO);
        boolean isBinary = compilationDTO.isBinary(categoryTableMap.size());
        SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> probabilityMapFunction =
                getProbabilityMapFunction(compilationDTO.getModelNormalizationMethod(), isBinary);
        return (KiePMMLClassificationTable) KiePMMLClassificationTable.builder(UUID.randomUUID().toString(),
                                                                              Collections.emptyList())
                .withRegressionNormalizationMethod(compilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD())
                .withOpType(compilationDTO.getOP_TYPE())
                .withCategoryTableMap(categoryTableMap)
                .withProbabilityMapFunction(probabilityMapFunction)
                .withIsBinary(isBinary)
                .withTargetField(compilationDTO.getTargetFieldName())
                .build();
    }

    /**
     * Generates source code for every regression table of the model plus one classification table referencing them.
     *
     * @param compilationDTO compilation data of the regression model
     * @return map from fully-qualified class name to its source/category; the classification table entry is
     * added last, with an empty category
     */
    public static Map<String, KiePMMLTableSourceCategory> getClassificationTableBuilders(RegressionCompilationDTO compilationDTO) {
        logger.trace("getRegressionTables {}", compilationDTO.getRegressionTables());
        LinkedHashMap<String, KiePMMLTableSourceCategory> toReturn =
                KiePMMLRegressionTableFactory.getRegressionTableBuilders(compilationDTO);
        // The classification table is generated from the regression tables only, so it must be built
        // before its own entry is added to the map.
        Map.Entry<String, String> classificationTableEntry = getClassificationTableBuilder(compilationDTO, toReturn);
        toReturn.put(classificationTableEntry.getKey(),
                     new KiePMMLTableSourceCategory(classificationTableEntry.getValue(), ""));
        return toReturn;
    }

    /**
     * Generates the source of one classification table class from the template.
     * <p>
     * Every call allocates a new class name ({@code KiePMMLClassificationTable<N>}).
     *
     * @param compilationDTO compilation data of the regression model (provides the package name)
     * @param regressionTablesMap generated regression tables, keyed by class name
     * @return entry of (fully-qualified class name, generated source)
     * @throws KiePMMLException if the template class or its getter cannot be found
     */
    public static Map.Entry<String, String> getClassificationTableBuilder(RegressionCompilationDTO compilationDTO,
                                                                          LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap) {
        logger.trace("getRegressionTableBuilder {}", regressionTablesMap);
        String className = "KiePMMLClassificationTable" + classArity.addAndGet(1);
        CompilationUnit cloneCU = JavaParserUtils.getKiePMMLModelCompilationUnit(className,
                                                                                 compilationDTO.getPackageName(),
                                                                                 KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA,
                                                                                 KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE);
        ClassOrInterfaceDeclaration tableTemplate = cloneCU.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException("Main class not found: " + className));
        MethodDeclaration staticGetterMethod = tableTemplate.getMethodsByName(GETKIEPMML_TABLE).stream()
                .findFirst()
                .orElseThrow(() -> new KiePMMLException(String.format("Method %s not found in %s",
                                                                      GETKIEPMML_TABLE, className)));
        // Locale.ROOT keeps the generated identifier independent of the JVM default locale.
        setStaticGetter(compilationDTO, regressionTablesMap, staticGetterMethod, className.toLowerCase(Locale.ROOT));
        return new AbstractMap.SimpleEntry<>(JavaParserUtils.getFullClassName(cloneCU), cloneCU.toString());
    }

    /**
     * Returns the probability-map function matching the given normalization method.
     *
     * @throws KiePMMLInternalException if the method is in {@link #UNSUPPORTED_NORMALIZATION_METHODS}
     */
    static SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> getProbabilityMapFunction(RegressionModel.NormalizationMethod normalizationMethod,
                                                                                                                        boolean isBinary) {
        checkNormalizationMethodSupported(normalizationMethod);
        return getProbabilityMapFunctionSupported(normalizationMethod, isBinary);
    }

    /**
     * Maps a supported normalization method to the matching static method of {@link KiePMMLClassificationTable}.
     * For {@code NONE}, a dedicated variant is used when the table is binary.
     *
     * @throws KiePMMLException for any method without a mapping
     */
    static SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> getProbabilityMapFunctionSupported(RegressionModel.NormalizationMethod normalizationMethod,
                                                                                                                                 boolean isBinary) {
        switch (normalizationMethod) {
            case SOFTMAX:
                return KiePMMLClassificationTable::getSOFTMAXProbabilityMap;
            case SIMPLEMAX:
                return KiePMMLClassificationTable::getSIMPLEMAXProbabilityMap;
            case NONE:
                return isBinary ? KiePMMLClassificationTable::getNONEBinaryProbabilityMap :
                        KiePMMLClassificationTable::getNONEProbabilityMap;
            case LOGIT:
                return KiePMMLClassificationTable::getLOGITProbabilityMap;
            case PROBIT:
                return KiePMMLClassificationTable::getPROBITProbabilityMap;
            case CLOGLOG:
                return KiePMMLClassificationTable::getCLOGLOGProbabilityMap;
            case CAUCHIT:
                return KiePMMLClassificationTable::getCAUCHITProbabilityMap;
            default:
                throw new KiePMMLException("Unexpected NormalizationMethod " + normalizationMethod);
        }
    }

    /**
     * Rewrites the body of the template getter so that it builds the classification table for the given tables.
     * <p>
     * The new body first declares a populated map {@code categoryTableMap_<variableName>} (category to
     * {@code <regressionTableClass>.getKiePMMLTable()}). It then re-adds the original statements, whose
     * {@code toReturn} builder chain now receives the model-specific arguments.
     *
     * @param compilationDTO compilation data of the regression model
     * @param regressionTablesMap generated regression tables, keyed by class name
     * @param staticGetterMethod getter declaration to modify in place
     * @param variableName builder name argument and suffix of the generated map variable
     * @throws KiePMMLException if the getter has no body or lacks an initialized {@code toReturn} variable
     */
    static void setStaticGetter(RegressionCompilationDTO compilationDTO,
                                LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap,
                                MethodDeclaration staticGetterMethod,
                                String variableName) {
        BlockStmt classificationTableBody = staticGetterMethod.getBody()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing body in %s", staticGetterMethod)));
        VariableDeclarator variableDeclarator = CommonCodegenUtils.getVariableDeclarator(classificationTableBody, TO_RETURN)
                .orElseThrow(() -> new KiePMMLException(String.format("Missing expected variable '%s' in body %s",
                                                                      TO_RETURN, classificationTableBody)));
        BlockStmt newBody = new BlockStmt();

        // category -> <RegressionTableClass>.getKiePMMLTable()
        Map<String, Expression> regressionTableCategoriesMap = new LinkedHashMap<>();
        regressionTablesMap.forEach((className, tableSourceCategory) -> {
            MethodCallExpr methodCallExpr = new MethodCallExpr();
            methodCallExpr.setScope(new NameExpr(className));
            methodCallExpr.setName(GETKIEPMML_TABLE);
            regressionTableCategoriesMap.put(tableSourceCategory.getCategory(), methodCallExpr);
        });
        String categoryTableMapName = String.format("%s_%s", CATEGORICAL_TABLE_MAP, variableName);
        CommonCodegenUtils.createPopulatedLinkedHashMap(newBody, categoryTableMapName,
                                                        Arrays.asList(String.class.getSimpleName(),
                                                                      KiePMMLRegressionTable.class.getName()),
                                                        regressionTableCategoriesMap);

        MethodCallExpr initializer = variableDeclarator.getInitializer()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing '%s' initializer in %s", TO_RETURN,
                                                                      classificationTableBody)))
                .asMethodCallExpr();
        CommonCodegenUtils.getChainedMethodCallExprFrom("builder", initializer)
                .setArgument(0, new StringLiteralExpr(variableName));
        REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod =
                compilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD();
        CommonCodegenUtils.getChainedMethodCallExprFrom("withRegressionNormalizationMethod", initializer)
                .setArgument(0, getEnumConstantExpression(regressionNormalizationMethod));
        OP_TYPE opType = compilationDTO.getOP_TYPE();
        CommonCodegenUtils.getChainedMethodCallExprFrom("withOpType", initializer)
                .setArgument(0, getEnumConstantExpression(opType));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withCategoryTableMap", initializer)
                .setArgument(0, new NameExpr(categoryTableMapName));
        boolean isBinary = compilationDTO.isBinary(regressionTablesMap.size());
        Expression probabilityMapFunctionExpression =
                getProbabilityMapFunctionExpression(compilationDTO.getModelNormalizationMethod(), isBinary);
        CommonCodegenUtils.getChainedMethodCallExprFrom("withProbabilityMapFunction", initializer)
                .setArgument(0, probabilityMapFunctionExpression);
        CommonCodegenUtils.getChainedMethodCallExprFrom("withIsBinary", initializer)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(isBinary));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withTargetField", initializer)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(compilationDTO.getTargetFieldName()));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withTargetCategory", initializer)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(null));

        // The map declaration must precede the original statements that reference it.
        classificationTableBody.getStatements().forEach(newBody::addStatement);
        staticGetterMethod.setBody(newBody);
    }

    /**
     * Returns the source expression selecting the probability-map function for the given normalization method.
     *
     * @throws KiePMMLInternalException if the method is in {@link #UNSUPPORTED_NORMALIZATION_METHODS}
     */
    static Expression getProbabilityMapFunctionExpression(RegressionModel.NormalizationMethod normalizationMethod,
                                                          boolean isBinary) {
        checkNormalizationMethodSupported(normalizationMethod);
        return getProbabilityMapFunctionSupportedExpression(normalizationMethod, isBinary);
    }

    /**
     * Builds the AST of a cast method reference of the shape
     * {@code (SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>>)
     * KiePMMLClassificationTable::get<NAME>ProbabilityMap}.
     * {@code NAME} is the enum constant name, suffixed with {@code Binary} for {@code NONE} on binary tables.
     */
    static MethodReferenceExpr getProbabilityMapFunctionSupportedExpression(RegressionModel.NormalizationMethod normalizationMethod,
                                                                            boolean isBinary) {
        String normalizationName = normalizationMethod.name();
        if (RegressionModel.NormalizationMethod.NONE.equals(normalizationMethod) && isBinary) {
            normalizationName = normalizationName + "Binary";
        }
        String thisExpressionMethodName = String.format("get%sProbabilityMap", normalizationName);
        String stringClassName = String.class.getSimpleName();
        String doubleClassName = Double.class.getSimpleName();
        ClassOrInterfaceType linkedHashMapReferenceType =
                CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(LinkedHashMap.class.getCanonicalName(),
                                                                           Arrays.asList(stringClassName,
                                                                                         doubleClassName));
        ClassOrInterfaceType functionType =
                CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypes(SerializableFunction.class.getCanonicalName(),
                                                                       Arrays.asList(linkedHashMapReferenceType,
                                                                                     linkedHashMapReferenceType));
        CastExpr castExpr = new CastExpr();
        castExpr.setType(functionType);
        castExpr.setExpression("KiePMMLClassificationTable");
        MethodReferenceExpr toReturn = new MethodReferenceExpr();
        toReturn.setScope(castExpr);
        toReturn.setIdentifier(thisExpressionMethodName);
        return toReturn;
    }

    /**
     * Rejects normalization methods listed in {@link #UNSUPPORTED_NORMALIZATION_METHODS}.
     *
     * @throws KiePMMLInternalException if the method is unsupported
     */
    private static void checkNormalizationMethodSupported(RegressionModel.NormalizationMethod normalizationMethod) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            throw new KiePMMLInternalException(String.format("Unsupported NormalizationMethod %s", normalizationMethod));
        }
    }

    /**
     * Builds a {@code EnumType.CONSTANT} name expression for the given enum constant.
     * <p>
     * {@link Enum#getDeclaringClass()} is used rather than {@code getClass()}: for a constant that declares a body,
     * {@code getClass()} is an anonymous subclass with an empty simple name, which would emit {@code .CONSTANT}.
     */
    private static NameExpr getEnumConstantExpression(Enum<?> enumConstant) {
        return new NameExpr(enumConstant.getDeclaringClass().getSimpleName() + "." + enumConstant.name());
    }
}

