/*
 * Decompiled with CFR 0.152.
 */
package org.atmosphere.interceptor;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.atmosphere.cpr.Action;
import org.atmosphere.cpr.AtmosphereConfig;
import org.atmosphere.cpr.AtmosphereInterceptorAdapter;
import org.atmosphere.cpr.AtmosphereResource;
import org.atmosphere.cpr.AtmosphereResourceEvent;
import org.atmosphere.cpr.AtmosphereResourceEventListenerAdapter;
import org.atmosphere.interceptor.InvokationOrder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Interceptor that limits how many messages each client may send, keeping one token bucket per
 * {@link AtmosphereResource#uuid()}.
 *
 * <p>Each bucket holds at most {@code maxMessages} tokens and refills continuously at a rate of
 * {@code maxMessages} tokens per window. When a message arrives and the client's bucket is empty,
 * the configured {@link Policy} decides what happens:
 * <ul>
 *   <li>{@link Policy#DROP} &ndash; the message is dropped and
 *       {@link Action#SKIP_ATMOSPHEREHANDLER} is returned;</li>
 *   <li>{@link Policy#DISCONNECT} &ndash; the resource is closed and {@link Action#CANCELLED} is
 *       returned.</li>
 * </ul>
 *
 * <p>Init parameters read by {@link #configure(AtmosphereConfig)}:
 * <ul>
 *   <li>{@code org.atmosphere.rateLimit.maxMessages} &ndash; bucket capacity (default 100);</li>
 *   <li>{@code org.atmosphere.rateLimit.windowSeconds} &ndash; refill window in seconds
 *       (default 60);</li>
 *   <li>{@code org.atmosphere.rateLimit.policy} &ndash; {@code drop} (default) or
 *       {@code disconnect}.</li>
 * </ul>
 * Numeric values that are unparseable or not strictly positive are ignored with a warning, and the
 * default is used instead.
 *
 * <p>Per-client state lives in concurrent collections, and each bucket synchronizes its own
 * updates. Once {@link #postInspect(AtmosphereResource)} has attached a listener to a client's
 * resource, that client's bucket is discarded when the resource reports a disconnect or close.
 */
public class RateLimitingInterceptor
extends AtmosphereInterceptorAdapter {
    private static final Logger logger = LoggerFactory.getLogger(RateLimitingInterceptor.class);

    private static final String MAX_MESSAGES_PARAM = "org.atmosphere.rateLimit.maxMessages";
    private static final String WINDOW_SECONDS_PARAM = "org.atmosphere.rateLimit.windowSeconds";
    private static final String POLICY_PARAM = "org.atmosphere.rateLimit.policy";
    private static final int DEFAULT_MAX_MESSAGES = 100;
    private static final int DEFAULT_WINDOW_SECONDS = 60;
    private static final long NANOS_PER_SECOND = 1_000_000_000L;

    private int maxMessages = DEFAULT_MAX_MESSAGES;
    private long windowNanos = DEFAULT_WINDOW_SECONDS * NANOS_PER_SECOND;
    private Policy policy = Policy.DROP;
    /** Token bucket per client UUID, created lazily on that client's first inspected message. */
    private final Map<String, TokenBucket> buckets = new ConcurrentHashMap<>();
    /** UUIDs whose resource already carries the cleanup listener, so it is attached only once. */
    private final Set<String> registeredListeners = ConcurrentHashMap.newKeySet();
    private final AtomicLong totalDropped = new AtomicLong();
    private final AtomicLong totalDisconnected = new AtomicLong();

    /**
     * Reads the rate-limit init parameters.
     *
     * <p>An invalid numeric value falls back to its default with a warning. A zero or negative
     * window or capacity would otherwise disable limiting (division by zero in the refill) or
     * block clients permanently (negative refill).
     *
     * @param config the Atmosphere configuration supplying init parameters
     */
    @Override
    public void configure(AtmosphereConfig config) {
        this.maxMessages = positiveIntParam(config, MAX_MESSAGES_PARAM, DEFAULT_MAX_MESSAGES);
        int windowSeconds = positiveIntParam(config, WINDOW_SECONDS_PARAM, DEFAULT_WINDOW_SECONDS);
        this.windowNanos = windowSeconds * NANOS_PER_SECOND;
        this.policy = parsePolicy(config.getInitParameter(POLICY_PARAM, "drop"));
        logger.info("Rate limiting configured: {} messages/{} seconds, policy={}",
                this.maxMessages, windowSeconds, this.policy);
    }

    /**
     * Reads a strictly positive integer init parameter.
     *
     * @return the parsed value, or {@code defaultValue} (after logging a warning) when the
     *     parameter is missing, unparseable, or not greater than zero
     */
    private static int positiveIntParam(AtmosphereConfig config, String name, int defaultValue) {
        String raw = config.getInitParameter(name, String.valueOf(defaultValue));
        if (raw == null) {
            return defaultValue;
        }
        try {
            int value = Integer.parseInt(raw.strip());
            if (value > 0) {
                return value;
            }
            logger.warn("Ignoring non-positive value {} for {}, using default {}", value, name, defaultValue);
        } catch (NumberFormatException ignored) {
            logger.warn("Ignoring invalid value '{}' for {}, using default {}", raw, name, defaultValue);
        }
        return defaultValue;
    }

    /**
     * Maps the policy init parameter to a {@link Policy}. The match is case-insensitive.
     * Anything other than {@code disconnect} yields {@link Policy#DROP}; an unrecognized value
     * also logs a warning.
     */
    private static Policy parsePolicy(String value) {
        String normalized = value == null ? null : value.strip();
        if ("disconnect".equalsIgnoreCase(normalized)) {
            return Policy.DISCONNECT;
        }
        if (!"drop".equalsIgnoreCase(normalized)) {
            logger.warn("Unknown value '{}' for {}, using drop", value, POLICY_PARAM);
        }
        return Policy.DROP;
    }

    /**
     * Consumes one token from the client's bucket and applies the configured policy if none is
     * available.
     *
     * @param r the resource carrying the message
     * @return {@link Action#CONTINUE} when within the limit; otherwise
     *     {@link Action#SKIP_ATMOSPHEREHANDLER} for {@link Policy#DROP} or
     *     {@link Action#CANCELLED} for {@link Policy#DISCONNECT}
     */
    @Override
    public Action inspect(AtmosphereResource r) {
        super.inspect(r);
        String uuid = r.uuid();
        TokenBucket bucket = this.buckets.computeIfAbsent(uuid, k -> new TokenBucket(this.maxMessages, this.windowNanos));
        if (bucket.tryConsume()) {
            return Action.CONTINUE;
        }
        // Exhaustive switch over the enum itself rather than over ordinal(): it stays correct if
        // constants are reordered, and the compiler flags any newly added policy.
        return switch (this.policy) {
            case DROP -> {
                this.totalDropped.incrementAndGet();
                logger.debug("Rate limit exceeded for client {}, dropping message", uuid);
                yield Action.SKIP_ATMOSPHEREHANDLER;
            }
            case DISCONNECT -> {
                this.totalDisconnected.incrementAndGet();
                logger.warn("Rate limit exceeded for client {}, disconnecting", uuid);
                closeQuietly(r, uuid);
                yield Action.CANCELLED;
            }
        };
    }

    /** Closes the resource. Failures are logged at debug level rather than propagated. */
    private static void closeQuietly(AtmosphereResource r, String uuid) {
        try {
            r.close();
        } catch (Exception e) {
            // Best effort: the message has already been rejected; a failed close must not
            // disturb the interceptor chain.
            logger.debug("Error closing rate-limited resource {}", uuid, e);
        }
    }

    /**
     * Attaches a listener that discards the client's bucket on disconnect or close. The listener
     * is attached once per UUID and only when a bucket exists for that UUID.
     *
     * @param r the resource that was just inspected
     */
    @Override
    public void postInspect(AtmosphereResource r) {
        final String uuid = r.uuid();
        if (this.buckets.containsKey(uuid) && this.registeredListeners.add(uuid)) {
            r.addEventListener(new AtmosphereResourceEventListenerAdapter() {

                @Override
                public void onDisconnect(AtmosphereResourceEvent event) {
                    forget(uuid);
                }

                @Override
                public void onClose(AtmosphereResourceEvent event) {
                    forget(uuid);
                }
            });
        }
    }

    /** Drops all per-client state for {@code uuid}. */
    private void forget(String uuid) {
        this.buckets.remove(uuid);
        this.registeredListeners.remove(uuid);
    }

    /**
     * Releases all per-client state. Both collections are cleared so that no UUID bookkeeping
     * outlives the buckets it refers to.
     */
    @Override
    public void destroy() {
        this.buckets.clear();
        this.registeredListeners.clear();
    }

    /** @return the number of messages dropped under {@link Policy#DROP} since creation */
    public long totalDropped() {
        return this.totalDropped.get();
    }

    /** @return the number of clients disconnected under {@link Policy#DISCONNECT} since creation */
    public long totalDisconnected() {
        return this.totalDisconnected.get();
    }

    /** @return the configured bucket capacity (messages per window) */
    public int maxMessages() {
        return this.maxMessages;
    }

    /** @return the configured over-limit policy */
    public Policy policy() {
        return this.policy;
    }

    /** @return the number of clients that currently have a token bucket */
    public int trackedClients() {
        return this.buckets.size();
    }

    /** @return {@link InvokationOrder#BEFORE_DEFAULT} */
    @Override
    public InvokationOrder.PRIORITY priority() {
        return InvokationOrder.BEFORE_DEFAULT;
    }

    @Override
    public String toString() {
        return "RateLimitingInterceptor{maxMessages=" + this.maxMessages + ", windowNanos=" + this.windowNanos + ", policy=" + this.policy + "}";
    }

    /** What to do with a message from a client whose bucket is empty. */
    public enum Policy {
        /** Drop the message and keep the connection. */
        DROP,
        /** Close the client's resource. */
        DISCONNECT
    }

    /**
     * Continuously refilling token bucket. It starts full with {@code maxTokens} tokens and gains
     * {@code maxTokens} tokens per {@code windowNanos}, capped at {@code maxTokens}.
     * {@link #tryConsume()} is {@code synchronized}, so a single bucket may be shared across
     * threads.
     */
    static final class TokenBucket {
        private final int maxTokens;
        private final long windowNanos;
        private double tokens;
        private long lastRefillNanos;

        TokenBucket(int maxTokens, long windowNanos) {
            this.maxTokens = maxTokens;
            this.windowNanos = windowNanos;
            this.tokens = maxTokens;
            this.lastRefillNanos = System.nanoTime();
        }

        /** @return {@code true} if a token was available and consumed, {@code false} otherwise */
        synchronized boolean tryConsume() {
            refill();
            if (this.tokens >= 1.0) {
                this.tokens -= 1.0;
                return true;
            }
            return false;
        }

        /** Adds tokens proportional to the time elapsed since the last refill. */
        private void refill() {
            long now = System.nanoTime();
            long elapsed = now - this.lastRefillNanos;
            if (elapsed <= 0L) {
                return;
            }
            double newTokens = (double) elapsed / (double) this.windowNanos * (double) this.maxTokens;
            this.tokens = Math.min((double) this.maxTokens, this.tokens + newTokens);
            this.lastRefillNanos = now;
        }
    }
}

