/*
 * Decompiled with CFR 0.152.
 */
package com.dtsx.astra.sdk.cassio;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.Statement;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.dtsx.astra.sdk.cassio.AbstractCassandraTable;
import com.dtsx.astra.sdk.cassio.AnnQuery;
import com.dtsx.astra.sdk.cassio.AnnResult;
import com.dtsx.astra.sdk.cassio.CassandraSimilarityMetric;
import com.dtsx.astra.sdk.cassio.MetadataVectorRecord;
import io.stargate.sdk.utils.AnsiUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Cassandra table storing records made of a text id, a text body, an attributes blob,
 * a {@code map<text, text>} of metadata and a fixed-dimension float vector.
 *
 * <p>On construction the table is created (if absent) together with two SAI indexes:
 * one on {@code vector}, configured with the table's {@link CassandraSimilarityMetric},
 * and one on {@code ENTRIES(metadata_s)} so that metadata equality filters can be used
 * in ANN queries.
 */
public class MetadataVectorTable
extends AbstractCassandraTable<MetadataVectorRecord> {
    private static final Logger log = LoggerFactory.getLogger(MetadataVectorTable.class);

    /** Number of rows returned by {@link #similaritySearch(AnnQuery)} when the query asks for {@code <= 0}. */
    private static final int DEFAULT_MAX_RECORD = 4;

    private final int vectorDimension;
    private final CassandraSimilarityMetric similarityMetric;

    /**
     * Creates (if needed) the table and its indexes, using cosine similarity for the vector index.
     *
     * @param session         CQL session used for every statement issued by this table
     * @param keyspaceName    keyspace name, forwarded to the parent and used by {@link #put}
     * @param tableName       table name
     * @param vectorDimension dimension of the {@code vector} column
     */
    public MetadataVectorTable(CqlSession session, String keyspaceName, String tableName, int vectorDimension) {
        this(session, keyspaceName, tableName, vectorDimension, CassandraSimilarityMetric.COSINE);
    }

    /**
     * Creates (if needed) the table and its indexes.
     *
     * @param session         CQL session used for every statement issued by this table
     * @param keyspaceName    keyspace name, forwarded to the parent and used by {@link #put}
     * @param tableName       table name
     * @param vectorDimension dimension of the {@code vector} column
     * @param metric          similarity function configured on the vector index; must not be null
     * @throws NullPointerException if {@code metric} is null
     */
    public MetadataVectorTable(CqlSession session, String keyspaceName, String tableName, int vectorDimension, CassandraSimilarityMetric metric) {
        super(session, keyspaceName, tableName);
        // Fail before issuing any DDL rather than after the table has been created.
        this.similarityMetric = Objects.requireNonNull(metric, "metric");
        this.vectorDimension = vectorDimension;
        // NOTE(review): create() is overridable and runs from the constructor, before any
        // subclass fields are initialized.
        this.create();
    }

    /**
     * Creates the table, the vector SAI index and the metadata-entries SAI index if they do
     * not already exist.
     *
     * <p>Statements use the unqualified table name, so they resolve against the session's
     * current keyspace.
     */
    @Override
    public void create() {
        String cqlQueryCreateTable = "CREATE TABLE IF NOT EXISTS " + this.tableName + " (row_id text, attributes_blob text, body_blob text, metadata_s map<text, text>, vector vector<float, " + this.vectorDimension + ">, PRIMARY KEY (row_id))";
        this.cqlSession.execute(cqlQueryCreateTable);
        log.info("Table '{}' has been created (if needed).", (Object)this.tableName);

        String vectorIndexName = "idx_vector_" + this.tableName;
        this.cqlSession.execute("CREATE CUSTOM INDEX IF NOT EXISTS " + vectorIndexName + " ON " + this.tableName + " (vector) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': '" + this.similarityMetric.getOption() + "'};");
        log.info("Index '{}' has been created (if needed).", (Object)vectorIndexName);

        String metadataIndexName = "eidx_metadata_s_" + this.tableName;
        this.cqlSession.execute("CREATE CUSTOM INDEX IF NOT EXISTS " + metadataIndexName + " ON " + this.tableName + " (ENTRIES(metadata_s)) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';");
        log.info("Index '{}' has been created (if needed).", (Object)metadataIndexName);
    }

    /**
     * Inserts a record using the statement built by {@link MetadataVectorRecord#insertStatement}.
     *
     * @param row record to write
     */
    @Override
    public void put(MetadataVectorRecord row) {
        this.cqlSession.execute((Statement)row.insertStatement(this.keyspaceName, this.tableName));
    }

    /**
     * Maps an ANN result row, which must carry a {@code similarity} column, to an {@link AnnResult}.
     *
     * @param cqlRow result row, may be null
     * @return the mapped result, or {@code null} when {@code cqlRow} is null
     */
    private AnnResult<MetadataVectorRecord> mapResult(Row cqlRow) {
        if (cqlRow == null) {
            return null;
        }
        AnnResult<MetadataVectorRecord> res = new AnnResult<MetadataVectorRecord>();
        res.setEmbedded(this.mapRow(cqlRow));
        res.setSimilarity(cqlRow.getFloat("similarity"));
        log.debug("Result similarity '{}' for embedded id='{}'", (Object)Float.valueOf(res.getSimilarity()), (Object)((MetadataVectorRecord)res.getEmbedded()).getRowId());
        return res;
    }

    /**
     * Maps a row of this table to a {@link MetadataVectorRecord}.
     *
     * <p>{@code row_id}, {@code body_blob} and {@code vector} are always read; {@code attributes_blob}
     * and {@code metadata_s} are read only when present in the row's column definitions.
     *
     * @param cqlRow row to map, may be null
     * @return the mapped record, or {@code null} when {@code cqlRow} is null
     * @throws NullPointerException if the row has a null {@code vector} value
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public MetadataVectorRecord mapRow(Row cqlRow) {
        if (cqlRow == null) {
            return null;
        }
        MetadataVectorRecord record = new MetadataVectorRecord();
        record.setRowId(cqlRow.getString("row_id"));
        record.setBody(cqlRow.getString("body_blob"));
        // Raw CqlVector kept as in the original mapping; elements are copied into a List.
        CqlVector vector = (CqlVector)Objects.requireNonNull(cqlRow.getObject("vector"), "vector column is null");
        record.setVector((List)vector.stream().collect(Collectors.toList()));
        if (cqlRow.getColumnDefinitions().contains("attributes_blob")) {
            record.setAttributes(cqlRow.getString("attributes_blob"));
        }
        if (cqlRow.getColumnDefinitions().contains("metadata_s")) {
            record.setMetadata(cqlRow.getMap("metadata_s", String.class, String.class));
        }
        return record;
    }

    /**
     * Runs an ANN search ordered by {@code vector}, optionally filtered by exact metadata
     * entries, and keeps only results whose similarity is at least the query threshold.
     *
     * <p>When the query's record count is {@code <= 0}, {@value #DEFAULT_MAX_RECORD} rows are
     * requested. When the query carries no metric, the table's own {@link #getSimilarityMetric()}
     * is used to compute the similarity column.
     *
     * @param query search parameters (embedding, metadata filters, record count, threshold)
     * @return matching results, never null
     * @throws NullPointerException if {@code query} or its embeddings are null
     */
    public List<AnnResult<MetadataVectorRecord>> similaritySearch(AnnQuery query) {
        Objects.requireNonNull(query, "query");
        Objects.requireNonNull(query.getEmbeddings(), "query embeddings");
        int maxRecord = query.getRecordCount() > 0 ? query.getRecordCount() : DEFAULT_MAX_RECORD;
        CassandraSimilarityMetric metric = query.getMetric() != null ? query.getMetric() : this.similarityMetric;
        String cql = this.buildSimilarityQuery(metric, query.getMetaData());

        log.debug("Query on table '{}' with vector size '{}' and max record='{}'", new Object[]{AnsiUtils.yellow((String)this.tableName), AnsiUtils.cyan((String)("[" + query.getEmbeddings().size() + "]")), AnsiUtils.cyan((String)("" + maxRecord))});

        SimpleStatement statement = SimpleStatement.builder((String)cql)
                .addNamedValue("vector", (Object)CqlVector.newInstance(query.getEmbeddings()))
                .addNamedValue("maxRecord", (Object)maxRecord)
                .build();
        return this.cqlSession.execute((Statement)statement).all().stream()
                .map(this::mapResult)
                .filter(r -> (double)r.getSimilarity() >= query.getThreshold())
                .collect(Collectors.toList());
    }

    /**
     * Builds the CQL text of the ANN query. It uses the named bind markers {@code :vector}
     * and {@code :maxRecord}.
     *
     * @param metric   metric whose function computes the {@code similarity} column
     * @param metadata optional exact-match filters on {@code metadata_s} entries
     * @return CQL query text
     */
    private String buildSimilarityQuery(CassandraSimilarityMetric metric, Map<String, String> metadata) {
        StringBuilder cqlQuery = new StringBuilder("SELECT row_id,vector,body_blob,attributes_blob,metadata_s,");
        cqlQuery.append(metric.getFunction()).append("(vector, :vector) as ").append("similarity");
        cqlQuery.append(" FROM ").append(this.tableName);
        if (metadata != null && !metadata.isEmpty()) {
            // Keys and values are embedded as CQL string literals, so quotes must be escaped
            // to keep caller-supplied metadata from breaking (or injecting into) the query.
            String predicates = metadata.entrySet().stream()
                    .map(entry -> "metadata_s['" + escapeCqlLiteral(entry.getKey()) + "'] = '" + escapeCqlLiteral(entry.getValue()) + "'")
                    .collect(Collectors.joining(" AND "));
            cqlQuery.append(" WHERE ").append(predicates);
        }
        cqlQuery.append(" ORDER BY vector ANN OF :vector ");
        cqlQuery.append(" LIMIT :maxRecord");
        return cqlQuery.toString();
    }

    /**
     * Escapes text for use inside a single-quoted CQL string literal by doubling every
     * embedded single quote. A null value is rendered as the text {@code null}, as the
     * previous string concatenation did.
     */
    private static String escapeCqlLiteral(String value) {
        return String.valueOf(value).replace("'", "''");
    }

    /** @return the dimension of the {@code vector} column */
    public int getVectorDimension() {
        return this.vectorDimension;
    }

    /** @return the similarity metric configured on the vector index */
    public CassandraSimilarityMetric getSimilarityMetric() {
        return this.similarityMetric;
    }
}

