/*
 * Decompiled with CFR 0.152.
 */
package tech.picnic.errorprone.bugpatterns;

import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.errorprone.BugPattern;
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker;
import com.google.errorprone.fixes.Fix;
import com.google.errorprone.fixes.SuggestedFix;
import com.google.errorprone.fixes.SuggestedFixes;
import com.google.errorprone.matchers.ChildMultiMatcher;
import com.google.errorprone.matchers.Description;
import com.google.errorprone.matchers.Matcher;
import com.google.errorprone.matchers.Matchers;
import com.google.errorprone.suppliers.Supplier;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.AnnotationTree;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.NewArrayTree;
import com.sun.source.tree.ReturnTree;
import com.sun.source.tree.Tree;
import com.sun.source.util.TreeScanner;
import com.sun.tools.javac.code.Type;
import java.io.Serializable;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.jspecify.annotations.Nullable;
import tech.picnic.errorprone.utils.MoreJUnitMatchers;
import tech.picnic.errorprone.utils.SourceCode;

@BugPattern(summary="Prefer `@ValueSource` over a `@MethodSource` where possible and reasonable", linkType=BugPattern.LinkType.CUSTOM, link="https://error-prone.picnic.tech/bugpatterns/JUnitValueSource", severity=BugPattern.SeverityLevel.SUGGESTION, tags={"Simplification"})
@AutoService(value={BugChecker.class})
public final class JUnitValueSource
extends BugChecker
implements BugChecker.MethodTreeMatcher {
    private static final long serialVersionUID = 1L;
    private static final Matcher<ExpressionTree> SUPPORTED_VALUE_FACTORY_VALUES = Matchers.anyOf((Matcher[])new Matcher[]{JUnitValueSource.isArrayArgumentValueCandidate(), Matchers.toType(MethodInvocationTree.class, (Matcher)Matchers.allOf((Matcher[])new Matcher[]{Matchers.staticMethod().onClass("org.junit.jupiter.params.provider.Arguments").namedAnyOf(new String[]{"arguments", "of"}), Matchers.argumentCount((int)1), Matchers.argument((int)0, JUnitValueSource.isArrayArgumentValueCandidate())}))});
    private static final Matcher<ExpressionTree> ARRAY_OF_SUPPORTED_SINGLE_VALUE_ARGUMENTS = JUnitValueSource.isSingleDimensionArrayCreationWithAllElementsMatching(SUPPORTED_VALUE_FACTORY_VALUES);
    private static final Matcher<ExpressionTree> ENUMERATION_OF_SUPPORTED_SINGLE_VALUE_ARGUMENTS = Matchers.toType(MethodInvocationTree.class, (Matcher)Matchers.allOf((Matcher[])new Matcher[]{Matchers.staticMethod().onClassAny(new String[]{Stream.class.getCanonicalName(), IntStream.class.getCanonicalName(), LongStream.class.getCanonicalName(), DoubleStream.class.getCanonicalName(), List.class.getCanonicalName(), Set.class.getCanonicalName(), ImmutableList.class.getCanonicalName(), ImmutableSet.class.getCanonicalName()}).named("of"), Matchers.hasArguments((ChildMultiMatcher.MatchType)ChildMultiMatcher.MatchType.AT_LEAST_ONE, (Matcher)Matchers.anything()), Matchers.hasArguments((ChildMultiMatcher.MatchType)ChildMultiMatcher.MatchType.ALL, SUPPORTED_VALUE_FACTORY_VALUES)}));
    private static final Matcher<MethodTree> IS_UNARY_METHOD_WITH_SUPPORTED_PARAMETER = Matchers.methodHasParameters((Matcher[])new Matcher[]{Matchers.anyOf((Matcher[])new Matcher[]{Matchers.isPrimitiveOrBoxedPrimitiveType(), Matchers.isSameType(String.class), Matchers.isSameType((Supplier & Serializable)state -> state.getSymtab().classType)})});

    public Description matchMethod(MethodTree tree, VisitorState state) {
        if (!IS_UNARY_METHOD_WITH_SUPPORTED_PARAMETER.matches((Tree)tree, state)) {
            return Description.NO_MATCH;
        }
        Type parameterType = Objects.requireNonNull(ASTHelpers.getType((Tree)((Tree)Iterables.getOnlyElement(tree.getParameters()))), "Missing type for method parameter");
        return JUnitValueSource.findMethodSourceAnnotation(tree, state).flatMap(methodSourceAnnotation -> JUnitValueSource.getSoleLocalFactoryName(methodSourceAnnotation, tree).filter(factory -> !JUnitValueSource.hasSiblingReferencingValueFactory(tree, factory, state)).flatMap(factory -> JUnitValueSource.findSiblingWithName(tree, factory, state)).flatMap(factoryMethod -> JUnitValueSource.tryConstructValueSourceFix(parameterType, methodSourceAnnotation, factoryMethod, state)).map(fix -> this.describeMatch((Tree)methodSourceAnnotation, (Fix)fix))).orElse(Description.NO_MATCH);
    }

    private static Optional<String> getSoleLocalFactoryName(AnnotationTree methodSourceAnnotation, MethodTree method) {
        return JUnitValueSource.getElementIfSingleton(MoreJUnitMatchers.getMethodSourceFactoryNames((AnnotationTree)methodSourceAnnotation, (MethodTree)method)).filter(name -> name.indexOf(35) < 0);
    }

    private static boolean hasSiblingReferencingValueFactory(MethodTree tree, String valueFactory, VisitorState state) {
        return JUnitValueSource.findMatchingSibling(tree, m -> JUnitValueSource.hasValueFactory(m, valueFactory, state), state).isPresent();
    }

    private static Optional<MethodTree> findSiblingWithName(MethodTree tree, String methodName, VisitorState state) {
        return JUnitValueSource.findMatchingSibling(tree, m -> m.getName().contentEquals(methodName), state);
    }

    private static Optional<MethodTree> findMatchingSibling(MethodTree tree, Predicate<? super MethodTree> predicate, VisitorState state) {
        return Objects.requireNonNull((ClassTree)state.findEnclosing(new Class[]{ClassTree.class}), "No class enclosing method").getMembers().stream().filter(MethodTree.class::isInstance).map(MethodTree.class::cast).filter(Predicate.not((Predicate<MethodTree>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Z, equals(java.lang.Object ), (Lcom/sun/source/tree/MethodTree;)Z)((MethodTree)tree))).filter(predicate).findFirst();
    }

    private static boolean hasValueFactory(MethodTree tree, String valueFactoryMethodName, VisitorState state) {
        return JUnitValueSource.findMethodSourceAnnotation(tree, state).stream().anyMatch(annotation -> MoreJUnitMatchers.getMethodSourceFactoryNames((AnnotationTree)annotation, (MethodTree)tree).contains((Object)valueFactoryMethodName));
    }

    private static Optional<AnnotationTree> findMethodSourceAnnotation(MethodTree tree, VisitorState state) {
        return MoreJUnitMatchers.HAS_METHOD_SOURCE.multiMatchResult((Tree)tree, state).matchingNodes().stream().findFirst();
    }

    private static Optional<SuggestedFix> tryConstructValueSourceFix(Type parameterType, AnnotationTree methodSourceAnnotation, MethodTree valueFactoryMethod, VisitorState state) {
        return JUnitValueSource.getSingleReturnExpression(valueFactoryMethod).flatMap(expression -> JUnitValueSource.tryExtractValueSourceAttributeValue(expression, state)).map(valueSourceAttributeValue -> {
            SuggestedFix.Builder fix = SuggestedFix.builder();
            String valueSource = SuggestedFixes.qualifyType((VisitorState)state, (SuggestedFix.Builder)fix, (String)"org.junit.jupiter.params.provider.ValueSource");
            return fix.replace((Tree)methodSourceAnnotation, String.format("@%s(%s = %s)", valueSource, JUnitValueSource.toValueSourceAttributeName(parameterType), valueSourceAttributeValue)).delete((Tree)valueFactoryMethod).build();
        });
    }

    private static Optional<ExpressionTree> getSingleReturnExpression(MethodTree methodTree) {
        final ArrayList returnExpressions = new ArrayList();
        new TreeScanner<Void, Void>(){

            @Override
            public @Nullable Void visitClass(ClassTree node, @Nullable Void unused) {
                return null;
            }

            @Override
            public @Nullable Void visitReturn(ReturnTree node, @Nullable Void unused) {
                returnExpressions.add(node.getExpression());
                return (Void)super.visitReturn(node, null);
            }

            @Override
            public @Nullable Void visitLambdaExpression(LambdaExpressionTree node, @Nullable Void unused) {
                return null;
            }
        }.scan(methodTree, null);
        return JUnitValueSource.getElementIfSingleton(returnExpressions);
    }

    private static Optional<String> tryExtractValueSourceAttributeValue(ExpressionTree tree, VisitorState state) {
        List<? extends ExpressionTree> arguments;
        if (ENUMERATION_OF_SUPPORTED_SINGLE_VALUE_ARGUMENTS.matches((Tree)tree, state)) {
            arguments = ((MethodInvocationTree)tree).getArguments();
        } else if (ARRAY_OF_SUPPORTED_SINGLE_VALUE_ARGUMENTS.matches((Tree)tree, state)) {
            arguments = ((NewArrayTree)tree).getInitializers();
        } else {
            return Optional.empty();
        }
        return Optional.of(arguments.stream().map(arg -> {
            ExpressionTree expressionTree;
            if (arg instanceof MethodInvocationTree) {
                MethodInvocationTree methodInvocation = (MethodInvocationTree)arg;
                expressionTree = (ExpressionTree)Iterables.getOnlyElement(methodInvocation.getArguments());
            } else {
                expressionTree = arg;
            }
            return expressionTree;
        }).map(argument -> SourceCode.treeToString((Tree)argument, (VisitorState)state)).collect(Collectors.joining(", "))).map(value -> arguments.size() > 1 ? String.format("{%s}", value) : value);
    }

    private static String toValueSourceAttributeName(Type type) {
        String typeString;
        return switch (typeString = type.tsym.name.toString()) {
            case "Class" -> "classes";
            case "Character" -> "chars";
            case "Integer" -> "ints";
            default -> typeString.toLowerCase(Locale.ROOT) + "s";
        };
    }

    private static <T> Optional<T> getElementIfSingleton(Collection<T> collection) {
        return Optional.of(collection).filter(elements -> elements.size() == 1).map(Iterables::getOnlyElement);
    }

    private static Matcher<ExpressionTree> isSingleDimensionArrayCreationWithAllElementsMatching(Matcher<? super ExpressionTree> elementMatcher) {
        return (Matcher & Serializable)(tree, state) -> {
            if (!(tree instanceof NewArrayTree)) {
                return false;
            }
            NewArrayTree newArray = (NewArrayTree)tree;
            return newArray.getDimensions().isEmpty() && !newArray.getInitializers().isEmpty() && newArray.getInitializers().stream().allMatch(element -> elementMatcher.matches((Tree)element, state));
        };
    }

    private static Matcher<ExpressionTree> isArrayArgumentValueCandidate() {
        return Matchers.anyOf((Matcher[])new Matcher[]{Matchers.classLiteral((Matcher)Matchers.anything()), (Matcher & Serializable)(tree, state) -> ASTHelpers.constValue((Tree)tree) != null});
    }
}

