/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.cleanup;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Rewrites {@code StringBuilder.append(String)} calls whose single argument is a {@code +}
 * concatenation into a chain of {@code append} calls, e.g. {@code sb.append("id=" + id)} becomes
 * {@code sb.append("id=").append(id)}.
 *
 * <p>Arguments consisting only of literals are left untouched, and consecutive String literals
 * are kept together in one {@code append} (the recipe description notes such literals are joined
 * at compile time).
 */
public class ChainStringBuilderAppendCalls extends Recipe {
    /** Matches {@code StringBuilder.append(String)}, the only overload this recipe rewrites. */
    private static final MethodMatcher STRING_BUILDER_APPEND = new MethodMatcher("java.lang.StringBuilder append(String)");

    /**
     * Lazily parsed {@code "A" + "B"} binary used as the prototype for every {@code +} node this
     * recipe builds. Initialization is unsynchronized, so concurrent first calls may each parse
     * the snippet; the last assignment wins. NOTE(review): assumed harmless because the stored
     * node is only ever copied via {@code withX} methods — confirm.
     */
    private static J.Binary additiveBinaryTemplate = null;

    public String getDisplayName() {
        return "Chain `StringBuilder.append()` calls";
    }

    public String getDescription() {
        return "String concatenation within calls to `StringBuilder.append()` causes unnecessary memory allocation. Except for concatenations of String literals, which are joined together at compile time. Replaces inefficient concatenations with chained calls to `StringBuilder.append()`.";
    }

    /** Returns a fixed estimate of two minutes of effort per occurrence. */
    @Nullable
    public Duration getEstimatedEffortPerOccurrence() {
        return Duration.ofMinutes(2L);
    }

    /**
     * Applicability test: only source files in which {@link UsesMethod} finds a call matching
     * {@link #STRING_BUILDER_APPEND} are visited.
     */
    @Nullable
    protected TreeVisitor<?, ExecutionContext> getSingleSourceApplicableTest() {
        return new UsesMethod<ExecutionContext>(STRING_BUILDER_APPEND);
    }

    protected TreeVisitor<?, ExecutionContext> getVisitor() {
        return new JavaIsoVisitor<ExecutionContext>() {

            @Override
            public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
                // Visit the select and arguments first. Without this the visitor stops descending
                // at every method invocation, so appends nested in other calls (for example the
                // receiver of a trailing .toString()) would never be rewritten.
                J.MethodInvocation m = super.visitMethodInvocation(method, ctx);
                if (!STRING_BUILDER_APPEND.matches(m)) {
                    return m;
                }
                List<Expression> arguments = m.getArguments();
                if (arguments.size() != 1) {
                    return m;
                }
                List<Expression> operands = new ArrayList<>();
                if (!flatAdditiveExpressions(arguments.get(0), operands)) {
                    return m;
                }
                if (operands.stream().allMatch(exp -> exp instanceof J.Literal)) {
                    return m;
                }
                List<Expression> groups = groupOperands(operands);
                if (groups.size() < 2) {
                    // Nothing to chain. Returning the original keeps its formatting and comments,
                    // which rebuilding from the prefix-stripped operands would discard.
                    return m;
                }
                J.MethodInvocation chained = m.withArguments(Collections.singletonList(groups.get(0)));
                for (int i = 1; i < groups.size(); ++i) {
                    // Each new call wraps the previous one as its select; the original prefix stays
                    // on the innermost call, so the outer ones get an empty prefix.
                    chained = chained.withSelect(chained)
                            .withArguments(Collections.singletonList(groups.get(i)))
                            .withPrefix(Space.EMPTY);
                }
                return chained;
            }
        };
    }

    /**
     * Splits flattened {@code +} operands (left-to-right) into the arguments of successive
     * {@code append} calls.
     *
     * <p>Until the first String literal or String-typed operand is seen, operands accumulate in one
     * group. {@code +} is left-associative, so those operands may be combined by numeric addition
     * and must stay together. After that point every {@code +} is string concatenation:
     * consecutive String literals share a group, and every other operand becomes a group of its
     * own.
     *
     * @param operands leaf operands of the concatenation, in source order
     * @return the argument expressions for the chained {@code append} calls, in order
     */
    private static List<Expression> groupOperands(List<Expression> operands) {
        List<Expression> groups = new ArrayList<>();
        List<Expression> group = new ArrayList<>();
        boolean appendToString = false;
        for (Expression operand : operands) {
            if (appendToString) {
                if (isStringLiteral(operand)) {
                    group.add(operand);
                } else {
                    addToGroups(group, groups);
                    groups.add(operand);
                }
            } else if (isStringLiteral(operand)) {
                addToGroups(group, groups);
                appendToString = true;
                group.add(operand);
            } else if (isStringValued(operand)) {
                // A non-literal String gets its own append. Later literals must not be
                // concatenated onto it at runtime.
                addToGroups(group, groups);
                appendToString = true;
                groups.add(operand);
            } else {
                group.add(operand);
            }
        }
        addToGroups(group, groups);
        return groups;
    }

    /** Returns true if {@code expression} is a literal whose type is {@code String}. */
    private static boolean isStringLiteral(Expression expression) {
        return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.String;
    }

    /**
     * Returns true if {@code expression} is an identifier or method invocation whose attributed
     * type is {@code java.lang.String}. Untyped expressions yield false.
     */
    private static boolean isStringValued(Expression expression) {
        if (!(expression instanceof J.Identifier || expression instanceof J.MethodInvocation) || expression.getType() == null) {
            return false;
        }
        JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(expression.getType());
        return fullyQualified != null && fullyQualified.getFullyQualifiedName().equals("java.lang.String");
    }

    /**
     * Builds {@code left + right} by copying the template binary. The result takes the template
     * left operand's prefix. {@code right} gets one extra leading space, and any comments in its
     * prefix are dropped.
     */
    public static J.Binary concatAdditionBinary(Expression left, Expression right) {
        J.Binary template = getAdditiveBinaryTemplate();
        return template.withPrefix(template.getLeft().getPrefix())
                .withLeft(left)
                .withRight((Expression) right.withPrefix(Space.build(" " + right.getPrefix().getWhitespace(), Collections.emptyList())));
    }

    /**
     * Left-folds the non-null {@code expressions} into a chain of {@code +} binaries.
     *
     * @return the single expression when only one is non-null, or null when none are
     */
    public static Expression additiveExpression(Expression... expressions) {
        Expression result = null;
        for (Expression element : expressions) {
            if (element == null) {
                continue;
            }
            result = result == null ? element : concatAdditionBinary(result, element);
        }
        return result;
    }

    /** List overload of {@link #additiveExpression(Expression...)}. */
    public static Expression additiveExpression(List<Expression> expressions) {
        return additiveExpression(expressions.toArray(new Expression[0]));
    }

    /**
     * Returns the cached {@code "A" + "B"} binary, parsing a small snippet on first use and keeping
     * the first {@link J.Binary} the visitor reaches.
     */
    public static J.Binary getAdditiveBinaryTemplate() {
        if (additiveBinaryTemplate == null) {
            List<J.CompilationUnit> cus = JavaParser.fromJavaVersion().build().parse("class A { void foo() {String s = \"A\" + \"B\";}}");
            List<J.Binary> binaries = new ArrayList<>(1);
            new JavaIsoVisitor<List<J.Binary>>() {

                @Override
                public J.Binary visitBinary(J.Binary binary, List<J.Binary> found) {
                    found.add(binary);
                    return binary;
                }
            }.reduce(cus.get(0), binaries);
            additiveBinaryTemplate = binaries.get(0);
        }
        return additiveBinaryTemplate;
    }

    /** Moves the pending {@code group} into {@code groups} as one {@code +} expression, if non-empty. */
    private static void addToGroups(List<Expression> group, List<Expression> groups) {
        if (!group.isEmpty()) {
            groups.add(additiveExpression(group));
            group.clear();
        }
    }

    /**
     * Appends the leaf operands of a tree of {@code +} binaries to {@code expressionList} in
     * left-to-right order, each with its prefix cleared.
     *
     * @return false if any node is a binary other than addition, or a leaf other than a literal,
     *     identifier or method invocation; {@code expressionList} may then be partially filled
     */
    public static boolean flatAdditiveExpressions(Expression expression, List<Expression> expressionList) {
        if (expression instanceof J.Binary) {
            J.Binary binary = (J.Binary) expression;
            if (binary.getOperator() != J.Binary.Type.Addition) {
                return false;
            }
            return flatAdditiveExpressions(binary.getLeft(), expressionList)
                    && flatAdditiveExpressions(binary.getRight(), expressionList);
        }
        if (expression instanceof J.Literal || expression instanceof J.Identifier || expression instanceof J.MethodInvocation) {
            expressionList.add((Expression) expression.withPrefix(Space.EMPTY));
            return true;
        }
        return false;
    }
}

