/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.chat.memory.cassandra;

import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.Statement;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.term.Term;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;

public final class CassandraChatMemory
implements ChatMemory {
    public static final String CONVERSATION_TS = CassandraChatMemory.class.getSimpleName() + "_message_timestamp";
    final CassandraChatMemoryConfig conf;
    private final PreparedStatement addUserStmt;
    private final PreparedStatement addAssistantStmt;
    private final PreparedStatement getStmt;
    private final PreparedStatement deleteStmt;

    public CassandraChatMemory(CassandraChatMemoryConfig config) {
        this.conf = config;
        this.conf.ensureSchemaExists();
        this.addUserStmt = this.prepareAddStmt(this.conf.userColumn);
        this.addAssistantStmt = this.prepareAddStmt(this.conf.assistantColumn);
        this.getStmt = this.prepareGetStatement();
        this.deleteStmt = this.prepareDeleteStmt();
    }

    public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
        return new CassandraChatMemory(conf);
    }

    public void add(String conversationId, List<Message> messages) {
        AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
        messages.forEach(msg -> {
            if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
                msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
            }
            this.add(conversationId, (Message)msg);
        });
    }

    public void add(String sessionId, Message msg) {
        Preconditions.checkArgument((!msg.getMetadata().containsKey(CONVERSATION_TS) || msg.getMetadata().get(CONVERSATION_TS) instanceof Instant ? 1 : 0) != 0, (String)"messages only accept metadata '%s' entries of type Instant", (Object)CONVERSATION_TS);
        msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());
        PreparedStatement stmt = this.getStatement(msg);
        List primaryKeys = (List)this.conf.primaryKeyTranslator.apply(sessionId);
        BoundStatementBuilder builder = stmt.boundStatementBuilder(new Object[0]);
        for (int k = 0; k < primaryKeys.size(); ++k) {
            CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
            builder = (BoundStatementBuilder)builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
        }
        Instant instant = (Instant)msg.getMetadata().get(CONVERSATION_TS);
        builder = (BoundStatementBuilder)((BoundStatementBuilder)builder.setInstant("message_timestamp", instant)).setString("message", msg.getText());
        this.conf.session.execute((Statement)builder.build());
    }

    PreparedStatement getStatement(Message msg) {
        return switch (msg.getMessageType()) {
            case MessageType.USER -> this.addUserStmt;
            case MessageType.ASSISTANT -> this.addAssistantStmt;
            default -> throw new IllegalArgumentException("Cant add type " + String.valueOf(msg));
        };
    }

    public void clear(String sessionId) {
        List primaryKeys = (List)this.conf.primaryKeyTranslator.apply(sessionId);
        BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(new Object[0]);
        for (int k = 0; k < primaryKeys.size(); ++k) {
            CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
            builder = (BoundStatementBuilder)builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
        }
        this.conf.session.execute((Statement)builder.build());
    }

    public List<Message> get(String sessionId, int lastN) {
        List primaryKeys = (List)this.conf.primaryKeyTranslator.apply(sessionId);
        BoundStatementBuilder builder = (BoundStatementBuilder)this.getStmt.boundStatementBuilder(new Object[0]).setInt("lastN", lastN);
        for (int k = 0; k < primaryKeys.size(); ++k) {
            CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
            builder = (BoundStatementBuilder)builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
        }
        ArrayList<Message> messages = new ArrayList<Message>();
        for (Row r : this.conf.session.execute((Statement)builder.build())) {
            String assistant = r.getString(this.conf.assistantColumn);
            String user = r.getString(this.conf.userColumn);
            if (null != assistant) {
                messages.add((Message)new AssistantMessage(assistant));
            }
            if (null == user) continue;
            messages.add((Message)new UserMessage(user));
        }
        return messages;
    }

    private PreparedStatement prepareAddStmt(String column) {
        InsertInto stmt = null;
        InsertInto stmtStart = QueryBuilder.insertInto((String)this.conf.schema.keyspace(), (String)this.conf.schema.table());
        for (CassandraChatMemoryConfig.SchemaColumn c : this.conf.schema.partitionKeys()) {
            stmt = (null != stmt ? stmt : stmtStart).value(c.name(), (Term)QueryBuilder.bindMarker((String)c.name()));
        }
        for (CassandraChatMemoryConfig.SchemaColumn c : this.conf.schema.clusteringKeys()) {
            stmt = stmt.value(c.name(), (Term)QueryBuilder.bindMarker((String)c.name()));
        }
        stmt = stmt.value(column, (Term)QueryBuilder.bindMarker((String)"message"));
        return this.conf.session.prepare(stmt.build());
    }

    private PreparedStatement prepareGetStatement() {
        Select stmt = QueryBuilder.selectFrom((String)this.conf.schema.keyspace(), (String)this.conf.schema.table()).all();
        for (CassandraChatMemoryConfig.SchemaColumn c : this.conf.schema.partitionKeys()) {
            stmt = (Select)stmt.whereColumn(c.name()).isEqualTo((Term)QueryBuilder.bindMarker((String)c.name()));
        }
        int i = 0;
        while (i + 1 < this.conf.schema.clusteringKeys().size()) {
            String columnName = this.conf.schema.clusteringKeys().get(i).name();
            stmt = (Select)stmt.whereColumn(columnName).isEqualTo((Term)QueryBuilder.bindMarker((String)columnName));
            ++i;
        }
        stmt = stmt.limit(QueryBuilder.bindMarker((String)"lastN"));
        return this.conf.session.prepare(stmt.build());
    }

    private PreparedStatement prepareDeleteStmt() {
        DeleteSelection stmt = null;
        DeleteSelection stmtStart = QueryBuilder.deleteFrom((String)this.conf.schema.keyspace(), (String)this.conf.schema.table());
        for (CassandraChatMemoryConfig.SchemaColumn c : this.conf.schema.partitionKeys()) {
            stmt = (Delete)(null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo((Term)QueryBuilder.bindMarker((String)c.name()));
        }
        int i = 0;
        while (i + 1 < this.conf.schema.clusteringKeys().size()) {
            String columnName = this.conf.schema.clusteringKeys().get(i).name();
            stmt = (Delete)stmt.whereColumn(columnName).isEqualTo((Term)QueryBuilder.bindMarker((String)columnName));
            ++i;
        }
        return this.conf.session.prepare(stmt.build());
    }
}

