/*
 * Decompiled with CFR 0.152.
 */
package org.openrewrite.java.cleanup;

import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.Flag;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.tree.TypeTree;
import org.openrewrite.java.tree.TypeUtils;

/**
 * Replaces single-expression lambdas with an equivalent method reference (RSPEC-1612).
 *
 * <p>Handled shapes, where {@code o} is the lambda's only parameter unless noted otherwise:
 * <ul>
 *   <li>{@code o -> o instanceof X} becomes {@code X.class::isInstance}</li>
 *   <li>{@code o -> (X) o} becomes {@code X.class::cast}</li>
 *   <li>{@code o -> o == null} / {@code o -> o != null} become {@code Objects::isNull} /
 *       {@code Objects::nonNull}</li>
 *   <li>a body consisting of a single method or constructor call whose receiver and arguments are
 *       exactly the lambda parameters in order, e.g. {@code o -> System.out.println(o)},
 *       {@code (a, b) -> a.foo(b)}, {@code x -> new Foo(x)}</li>
 * </ul>
 * Lambdas whose type is {@code groovy.lang.Closure} are left untouched.
 */
public class ReplaceLambdaWithMethodReference extends Recipe {

    @Override
    public String getDisplayName() {
        return "Use method references in lambda";
    }

    @Override
    public String getDescription() {
        return "Replaces the single statement lambdas `o -> o instanceOf X`, `o -> (A) o`, `o -> System.out.println(o)`, `o -> o != null`, `o -> o == null` with the equivalent method reference.";
    }

    @Override
    public Set<String> getTags() {
        return Collections.singleton("RSPEC-1612");
    }

    @Override
    public Duration getEstimatedEffortPerOccurrence() {
        return Duration.ofMinutes(2L);
    }

    @Override
    public JavaVisitor<ExecutionContext> getVisitor() {
        return new JavaVisitor<ExecutionContext>() {

            @Override
            public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
                J.Lambda l = (J.Lambda) super.visitLambda(lambda, executionContext);
                if (TypeUtils.isOfClassType(lambda.getType(), "groovy.lang.Closure")) {
                    return l;
                }

                J body = unwrapSingleCallStatement(l.getBody());

                if (body instanceof J.InstanceOf) {
                    // o -> o instanceof X  ==>  X.class::isInstance
                    J.InstanceOf instanceOf = (J.InstanceOf) body;
                    J type = instanceOf.getClazz();
                    if (type instanceof J.Identifier && isSoleLambdaParameter(instanceOf.getExpression(), l)) {
                        return replaceWithClassLiteralReference(l, (J.Identifier) type, "#{}.class::isInstance");
                    }
                    return l;
                }

                if (body instanceof J.TypeCast) {
                    // o -> (X) o  ==>  X.class::cast; generic type variables are excluded.
                    J.TypeCast cast = (J.TypeCast) body;
                    J.ControlParentheses<TypeTree> castType = cast.getClazz();
                    if (castType != null
                            && castType.getTree() instanceof J.Identifier
                            && !(castType.getType() instanceof JavaType.GenericTypeVariable)
                            && isSoleLambdaParameter(cast.getExpression(), l)) {
                        return replaceWithClassLiteralReference(l, (J.Identifier) castType.getTree(), "#{}.class::cast");
                    }
                    return l;
                }

                if (body instanceof J.Binary) {
                    // o -> o == null / o -> o != null  ==>  Objects::isNull / Objects::nonNull.
                    // Only equality operators qualify: e.g. `s -> s + null` is string concatenation.
                    J.Binary binary = (J.Binary) body;
                    J.Binary.Type operator = binary.getOperator();
                    boolean equality = operator == J.Binary.Type.Equal || operator == J.Binary.Type.NotEqual;
                    if (equality && (isNullCheck(binary.getLeft(), binary.getRight(), l)
                            || isNullCheck(binary.getRight(), binary.getLeft(), l))) {
                        maybeAddImport("java.util.Objects");
                        String code = operator == J.Binary.Type.Equal ? "Objects::isNull" : "Objects::nonNull";
                        return l.withTemplate(JavaTemplate.builder(this::getCursor, code).imports("java.util.Objects").build(),
                                l.getCoordinates().replace());
                    }
                    return l;
                }

                if (body instanceof MethodCall) {
                    return toMethodReference(lambda, l, (MethodCall) body);
                }
                return l;
            }

            /**
             * If {@code body} is a block holding exactly one statement that is a method invocation, or a
             * {@code return} of a method/constructor call, returns that call; otherwise returns {@code body}.
             */
            private J unwrapSingleCallStatement(J body) {
                if (body instanceof J.Block && ((J.Block) body).getStatements().size() == 1) {
                    Statement statement = ((J.Block) body).getStatements().get(0);
                    if (statement instanceof J.MethodInvocation) {
                        return statement;
                    }
                    if (statement instanceof J.Return && ((J.Return) statement).getExpression() instanceof MethodCall) {
                        return ((J.Return) statement).getExpression();
                    }
                }
                return body;
            }

            /**
             * Replaces {@code l} with {@code code} (a template such as {@code "#{}.class::cast"}), filling the
             * placeholder with the simple name of {@code type}.
             *
             * <p>When the type is resolved, a minimal stub of it is given to the template parser. It is also
             * imported unless it lives in the default package, which cannot be imported.
             */
            private J replaceWithClassLiteralReference(J.Lambda l, J.Identifier type, String code) {
                JavaTemplate.Builder builder = JavaTemplate.builder(this::getCursor, code);
                JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type.getType());
                if (fullyQualified != null) {
                    String packageName = fullyQualified.getPackageName();
                    boolean defaultPackage = packageName == null || packageName.isEmpty();
                    String stub = (defaultPackage ? "" : "package " + packageName + "; ")
                            + "public class " + fullyQualified.getClassName() + " {}";
                    builder = builder.javaParser(JavaParser.fromJavaVersion().dependsOn(stub));
                    if (!defaultPackage) {
                        builder = builder.imports(fullyQualified.getFullyQualifiedName());
                    }
                }
                return l.withTemplate(builder.build(), l.getCoordinates().replace(), type.getSimpleName());
            }

            /**
             * Rewrites {@code l} into a method or constructor reference when its body is the single call
             * {@code method} and the call's receiver/arguments are exactly the lambda parameters in order.
             * Otherwise {@code l} is returned unchanged.
             *
             * @param original the lambda as it appears in the cursor's tree (before children were visited)
             * @param l        the lambda after its children were visited
             * @param method   the call forming the lambda body
             */
            private J toMethodReference(J.Lambda original, J.Lambda l, MethodCall method) {
                if (method instanceof J.NewClass) {
                    J.NewClass newClass = (J.NewClass) method;
                    if (newClass.getBody() != null) {
                        // Anonymous class creation has no constructor-reference equivalent.
                        return l;
                    }
                    // Skip when the lambda is a direct argument of a method call and the class declares more
                    // than one constructor (a constructor reference there may be ambiguous). The cursor's tree
                    // contains the original lambda instance, so the identity check must use `original`.
                    if (ReplaceLambdaWithMethodReference.isAMethodInvocationArgument(original, getCursor())
                            && newClass.getType() instanceof JavaType.Class
                            && ((JavaType.Class) newClass.getType()).getMethods().stream()
                                    .filter(JavaType.Method::isConstructor)
                                    .count() > 1L) {
                        return l;
                    }
                }
                if (multipleMethodInvocations(method) || !methodArgumentsMatchLambdaParameters(method, l)) {
                    return l;
                }
                JavaType.Method methodType = method.getMethodType();
                if (methodType == null) {
                    return l;
                }

                JavaType.FullyQualified declaringType = methodType.getDeclaringType();
                if (methodType.hasFlags(Flag.Static) || methodSelectMatchesFirstLambdaParameter(method, l)) {
                    // Static call, or `(a, ...) -> a.m(...)`: DeclaringType::m
                    maybeAddImport(declaringType);
                    return l.withTemplate(JavaTemplate.builder(this::getCursor, "#{}::#{}")
                                    .imports(declaringType.getFullyQualifiedName()).build(),
                            l.getCoordinates().replace(), declaringType.getClassName(), methodType.getName());
                }
                if (method instanceof J.NewClass) {
                    return l.withTemplate(JavaTemplate.builder(this::getCursor, "#{}::new").build(),
                            l.getCoordinates().replace(), className((J.NewClass) method));
                }

                // Bound receiver: `select::m`, or `this::m` when the call has no explicit receiver.
                Expression select = method instanceof J.MethodInvocation ? ((J.MethodInvocation) method).getSelect() : null;
                String template = select == null ? "#{}::#{}" : "#{any(" + declaringType.getFullyQualifiedName() + ")}::#{}";
                return l.withTemplate(JavaTemplate.builder(this::getCursor, template).build(),
                        l.getCoordinates().replace(), select == null ? "this" : select, methodType.getName());
            }

            /** Source text of the instantiated class, without type arguments for a parameterized type. */
            private String className(J.NewClass method) {
                TypeTree clazz = method.getClazz();
                return clazz instanceof J.ParameterizedType ? ((J.ParameterizedType) clazz).getClazz().toString() : Objects.toString(clazz);
            }

            /** True for chained invocations such as {@code a.b().c()}, which are not rewritten. */
            private boolean multipleMethodInvocations(MethodCall method) {
                return method instanceof J.MethodInvocation && ((J.MethodInvocation) method).getSelect() instanceof J.MethodInvocation;
            }

            /**
             * True when the call's (non-empty) arguments are identifiers referring to the lambda's parameters
             * in declaration order. For a non-static call whose receiver is the first lambda parameter, the
             * receiver consumes that parameter and the arguments must match the remaining ones.
             * Identifiers without variable type attribution never match.
             */
            private boolean methodArgumentsMatchLambdaParameters(MethodCall method, J.Lambda lambda) {
                JavaType.Method methodType = method.getMethodType();
                if (methodType == null) {
                    return false;
                }
                List<Expression> methodArgs = method.getArguments().stream()
                        .filter(a -> !(a instanceof J.Empty))
                        .collect(Collectors.toList());
                List<JavaType.Variable> lambdaParams = lambda.getParameters().getParameters().stream()
                        .filter(J.VariableDeclarations.class::isInstance)
                        .map(J.VariableDeclarations.class::cast)
                        .map(v -> v.getVariables().get(0).getVariableType())
                        .collect(Collectors.toList());
                if (methodArgs.isEmpty() && lambdaParams.isEmpty()) {
                    return true;
                }
                // The receiver was already matched against the first parameter by
                // methodSelectMatchesFirstLambdaParameter, so skip that parameter instead of
                // inserting into (possibly immutable) collector output.
                int offset = !methodType.hasFlags(Flag.Static) && methodSelectMatchesFirstLambdaParameter(method, lambda) ? 1 : 0;
                if (methodArgs.size() + offset != lambdaParams.size()) {
                    return false;
                }
                for (int i = 0; i < methodArgs.size(); i++) {
                    Expression arg = methodArgs.get(i);
                    JavaType.Variable lambdaParam = lambdaParams.get(i + offset);
                    if (!(arg instanceof J.Identifier) || lambdaParam == null
                            || lambdaParam != ((J.Identifier) arg).getFieldType()) {
                        return false;
                    }
                }
                return true;
            }

            /**
             * True when {@code method} is an invocation whose receiver is an identifier referring to the
             * lambda's first declared parameter (compared by attributed variable type, which must be present).
             */
            private boolean methodSelectMatchesFirstLambdaParameter(MethodCall method, J.Lambda lambda) {
                if (!(method instanceof J.MethodInvocation)) {
                    return false;
                }
                Expression select = ((J.MethodInvocation) method).getSelect();
                List<J> params = lambda.getParameters().getParameters();
                if (!(select instanceof J.Identifier) || params.isEmpty() || !(params.get(0) instanceof J.VariableDeclarations)) {
                    return false;
                }
                JavaType.Variable firstParam = ((J.VariableDeclarations) params.get(0)).getVariables().get(0).getVariableType();
                return firstParam != null && firstParam == ((J.Identifier) select).getFieldType();
            }

            /**
             * True when {@code candidate} is an identifier naming the one and only parameter of {@code lambda}.
             * Name comparison is sufficient here: inside the lambda body the parameter name shadows fields
             * and cannot be redeclared.
             */
            private boolean isSoleLambdaParameter(J candidate, J.Lambda lambda) {
                if (!(candidate instanceof J.Identifier)) {
                    return false;
                }
                List<J> params = lambda.getParameters().getParameters();
                if (params.size() != 1 || !(params.get(0) instanceof J.VariableDeclarations)) {
                    return false;
                }
                List<J.VariableDeclarations.NamedVariable> variables = ((J.VariableDeclarations) params.get(0)).getVariables();
                return variables.size() == 1
                        && variables.get(0).getSimpleName().equals(((J.Identifier) candidate).getSimpleName());
            }

            /** True when {@code operand} is the lambda's sole parameter and {@code other} is the literal {@code null}. */
            private boolean isNullCheck(J operand, J other, J.Lambda lambda) {
                return isSoleLambdaParameter(operand, lambda)
                        && other instanceof J.Literal
                        && "null".equals(((J.Literal) other).getValueSource());
            }
        };
    }

    /**
     * Returns true when the nearest enclosing {@link J.MethodInvocation} above {@code cursor} lists
     * {@code lambda} itself (by identity) among its arguments. Returns false if the nearest such
     * ancestor does not, or if none exists before the compilation unit.
     */
    private static boolean isAMethodInvocationArgument(J.Lambda lambda, Cursor cursor) {
        Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof J.CompilationUnit);
        if (parent.getValue() instanceof J.MethodInvocation) {
            J.MethodInvocation m = parent.getValue();
            return m.getArguments().stream().anyMatch(arg -> arg == lambda);
        }
        return false;
    }
}

