/*
 * Decompiled with CFR 0.152.
 */
package software.amazon.smithy.aws.traits.clientendpointdiscovery;

import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientDiscoveredEndpointTrait;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientEndpointDiscoveryIdTrait;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientEndpointDiscoveryIndex;
import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientEndpointDiscoveryTrait;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.ToShapeId;
import software.amazon.smithy.model.traits.ErrorTrait;
import software.amazon.smithy.model.transform.ModelTransformer;
import software.amazon.smithy.model.transform.ModelTransformerPlugin;
import software.amazon.smithy.utils.SmithyInternalApi;

@SmithyInternalApi
public final class CleanClientDiscoveryTraitTransformer
implements ModelTransformerPlugin {
    public Model onRemove(ModelTransformer transformer, Collection<Shape> shapes, Model model) {
        Set<ShapeId> removedOperations = shapes.stream().filter(Shape::isOperationShape).map(Shape::getId).collect(Collectors.toSet());
        Set<ShapeId> removedErrors = shapes.stream().filter(shape -> shape.hasTrait(ErrorTrait.class)).map(Shape::getId).collect(Collectors.toSet());
        Set<Shape> servicesToUpdate = this.getServicesToUpdate(model, removedOperations, removedErrors);
        HashSet<Shape> shapesToUpdate = new HashSet<Shape>(servicesToUpdate);
        Set<Shape> operationsToUpdate = this.getOperationsToUpdate(model, servicesToUpdate.stream().map(Shape::getId).collect(Collectors.toSet()));
        shapesToUpdate.addAll(operationsToUpdate);
        Set<Shape> membersToUpdate = this.getMembersToUpdate(model, operationsToUpdate.stream().map(Shape::getId).collect(Collectors.toSet()));
        shapesToUpdate.addAll(membersToUpdate);
        return transformer.replaceShapes(model, shapesToUpdate);
    }

    private Set<Shape> getServicesToUpdate(Model model, Set<ShapeId> removedOperations, Set<ShapeId> removedErrors) {
        HashSet<Shape> result = new HashSet<Shape>();
        for (ServiceShape service : model.getServiceShapesWithTrait(ClientEndpointDiscoveryTrait.class)) {
            ClientEndpointDiscoveryTrait trait = (ClientEndpointDiscoveryTrait)service.expectTrait(ClientEndpointDiscoveryTrait.class);
            if (!removedOperations.contains(trait.getOperation()) && !removedErrors.contains(trait.getError())) continue;
            ServiceShape.Builder builder = service.toBuilder();
            builder.removeTrait(ClientEndpointDiscoveryTrait.ID);
            result.add((Shape)builder.build());
        }
        return result;
    }

    private Set<Shape> getOperationsToUpdate(Model model, Set<ShapeId> updatedServices) {
        ClientEndpointDiscoveryIndex discoveryIndex = ClientEndpointDiscoveryIndex.of(model);
        Set stillBoundOperations = model.shapes(ServiceShape.class).filter(service -> service.hasTrait(ClientEndpointDiscoveryTrait.class)).map(Shape::getId).filter(service -> !updatedServices.contains(service)).flatMap(service -> discoveryIndex.getEndpointDiscoveryOperations((ToShapeId)service).stream()).collect(Collectors.toSet());
        HashSet<Shape> result = new HashSet<Shape>();
        for (OperationShape operation : model.getOperationShapesWithTrait(ClientDiscoveredEndpointTrait.class)) {
            ClientDiscoveredEndpointTrait trait = (ClientDiscoveredEndpointTrait)operation.expectTrait(ClientDiscoveredEndpointTrait.class);
            if (trait.isRequired() || stillBoundOperations.contains(operation.getId())) continue;
            result.add((Shape)((OperationShape.Builder)operation.toBuilder().removeTrait(ClientDiscoveredEndpointTrait.ID)).build());
        }
        return result;
    }

    private Set<Shape> getMembersToUpdate(Model model, Set<ShapeId> updatedOperations) {
        Set stillBoundMembers = model.shapes(OperationShape.class).filter(operation -> operation.hasTrait(ClientDiscoveredEndpointTrait.class)).filter(operation -> !updatedOperations.contains(operation.getId())).filter(operation -> operation.getInput().isPresent()).map(operation -> model.getShape((ShapeId)operation.getInput().get()).flatMap(Shape::asStructureShape)).filter(Optional::isPresent).flatMap(input -> ((StructureShape)input.get()).getAllMembers().values().stream()).map(Shape::getId).collect(Collectors.toSet());
        return model.shapes(MemberShape.class).filter(member -> member.hasTrait(ClientEndpointDiscoveryIdTrait.class)).filter(member -> !stillBoundMembers.contains(member.getId())).map(member -> ((MemberShape.Builder)member.toBuilder().removeTrait(ClientEndpointDiscoveryIdTrait.ID)).build()).collect(Collectors.toSet());
    }
}

