/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.avatar.logProcessor.leRobot;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import us.ihmc.avatar.logProcessor.leRobot.LeRobotDatasetEpisode;
import us.ihmc.avatar.logProcessor.leRobot.LeRobotDatasetTools;
import us.ihmc.avatar.scs2.SCS2LogSessionWithVideo;
import us.ihmc.commons.exception.DefaultExceptionHandler;
import us.ihmc.commons.exception.ExceptionHandler;
import us.ihmc.commons.nio.FileTools;
import us.ihmc.commons.nio.WriteOption;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.robotSide.RobotSide;
import us.ihmc.robotics.robotSide.SideDependentList;
import us.ihmc.scs2.session.log.LogDataReader;
import us.ihmc.tools.io.JSONFileTools;
import us.ihmc.yoVariables.variable.YoBoolean;
import us.ihmc.yoVariables.variable.YoDouble;

/**
 * In-memory view of a LeRobot-style dataset directory together with the operations that
 * create, load, extend, shrink and re-describe it on disk.
 *
 * <p>The directory layout used by this class is:
 * <ul>
 *   <li>{@code data/chunk-000/} - one {@code .parquet} file per episode</li>
 *   <li>{@code videos/chunk-000/observation.images.cam_zed_<side>/} - one {@code .mp4} per episode and side</li>
 *   <li>{@code meta/episodes.jsonl}, {@code meta/episodes_stats.jsonl}, {@code meta/info.json},
 *       {@code meta/tasks.jsonl}</li>
 * </ul>
 */
public class LeRobotDataset {
    /** Shared JSON mapper used when rewriting JSONL lines. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
    /** Side prefixes of the gripper pose entries listed in {@code info.json}. */
    private static final String[] GRIPPER_SIDES = {"left", "right"};
    /** Per-gripper pose components (position followed by quaternion) listed in {@code info.json}. */
    private static final String[] GRIPPER_POSE_COMPONENTS = {"x", "y", "z", "qx", "qy", "qz", "qs"};

    private final String name;
    private final Path directory;
    private final Path dataPath;
    private final Path metaPath;
    private final Path videosPath;
    private final Path dataChunk0Path;
    private final SideDependentList<Path> zedVideoDirs = new SideDependentList<>();
    private final Path episodesJsonlPath;
    private final Path episodeStatsJsonlPath;
    private final Path infoJsonPath;
    private final Path tasksJsonlPath;
    private final List<String> taskNames = new ArrayList<>();
    private final List<LeRobotDatasetEpisode> episodes = new ArrayList<>();
    /** Sum of the lengths of all known episodes; also used as the start offset of the next episode. */
    private long totalFrames = 0L;
    private boolean usePerfectTimestamps = true;

    /**
     * Resolves all dataset paths below {@code directory}. Nothing is created on disk; see {@link #mkdirs()}.
     *
     * @param directory root directory of the dataset; its file name becomes the dataset name
     */
    public LeRobotDataset(Path directory) {
        this.directory = directory;
        this.name = directory.getFileName().toString();
        this.dataPath = directory.resolve("data");
        this.metaPath = directory.resolve("meta");
        this.videosPath = directory.resolve("videos");
        this.dataChunk0Path = dataPath.resolve("chunk-000");
        for (RobotSide side : RobotSide.values) {
            zedVideoDirs.put(side, videosPath.resolve("chunk-000/observation.images.cam_zed_" + side.getLowerCaseName()));
        }
        this.episodesJsonlPath = metaPath.resolve("episodes.jsonl");
        this.episodeStatsJsonlPath = metaPath.resolve("episodes_stats.jsonl");
        this.infoJsonPath = metaPath.resolve("info.json");
        this.tasksJsonlPath = metaPath.resolve("tasks.jsonl");
    }

    /**
     * Ensures that every dataset directory and every metadata file exists.
     * Failures are passed to {@code DefaultExceptionHandler.PRINT_MESSAGE}.
     */
    public void mkdirs() {
        FileTools.ensureDirectoryExists(directory, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureDirectoryExists(dataPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureDirectoryExists(metaPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureDirectoryExists(videosPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureDirectoryExists(dataChunk0Path, DefaultExceptionHandler.PRINT_MESSAGE);
        for (RobotSide side : RobotSide.values) {
            FileTools.ensureDirectoryExists(zedVideoDirs.get(side), DefaultExceptionHandler.PRINT_MESSAGE);
        }
        FileTools.ensureFileExists(episodesJsonlPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureFileExists(episodeStatsJsonlPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureFileExists(infoJsonPath, DefaultExceptionHandler.PRINT_MESSAGE);
        FileTools.ensureFileExists(tasksJsonlPath, DefaultExceptionHandler.PRINT_MESSAGE);
    }

    /**
     * Rebuilds the in-memory task list and episode list from {@code tasks.jsonl} and
     * {@code episodes.jsonl}, loading the parquet data of every episode.
     *
     * <p>Episode indices are assigned in line order, and each episode is given the running
     * frame total of the episodes before it as its start offset.
     */
    public void loadData() {
        taskNames.clear();
        JSONFileTools.loadLines(tasksJsonlPath, lineRoot -> taskNames.add(lineRoot.get("task").textValue()));

        episodes.clear();
        // The running total doubles as the start offset of each episode, so it has to restart
        // from zero; otherwise a second load would stack on top of the previous total.
        totalFrames = 0L;
        JSONFileTools.loadLines(episodesJsonlPath, lineRoot -> {
            int episodeIndex = episodes.size();
            String taskName = lineRoot.get("tasks").get(0).textValue();
            int length = lineRoot.get("length").intValue();
            LeRobotDatasetEpisode episode = new LeRobotDatasetEpisode(episodeIndex, taskName, length, totalFrames, episodesJsonlPath, episodeStatsJsonlPath, dataChunk0Path, zedVideoDirs);
            episode.loadParquetData();
            episodes.add(episode);
            totalFrames += (long) length;
        });
    }

    /**
     * Registers the task if it is new, then appends a new episode and starts generating it
     * from the given log session.
     *
     * @param taskName             task the episode belongs to
     * @param session              log session handed to the episode generator
     * @param frameProcessingQueue queue handed to the episode generator
     */
    public void addEpisode(String taskName, SCS2LogSessionWithVideo session, Consumer<Runnable> frameProcessingQueue) {
        registerTaskIfNew(taskName);
        int episodeIndex = episodes.size();
        LeRobotDatasetEpisode episode = new LeRobotDatasetEpisode(episodeIndex, taskName, 0L, totalFrames, episodesJsonlPath, episodeStatsJsonlPath, dataChunk0Path, zedVideoDirs);
        episode.startGeneratingEpisode(session, this::writeMetaJson, frameProcessingQueue, usePerfectTimestamps);
        episodes.add(episode);
    }

    /**
     * Registers the task if it is new, then ticks through every entry of the log and prints,
     * to standard output, the length in frames of each contiguous run of frames whose
     * integer-truncated robot time is a multiple of 1000.
     *
     * <p>No episode is created by this method.
     *
     * @param taskName             task to register
     * @param session              log session to tick through
     * @param frameProcessingQueue currently unused
     */
    public void calculateEpisode(String taskName, SCS2LogSessionWithVideo session, Consumer<Runnable> frameProcessingQueue) {
        registerTaskIfNew(taskName);

        String highLevelController = "root.main.DRCControllerThread.DRCMomentumBasedController.HumanoidHighLevelControllerManager.";
        String wbcc = highLevelController + "HighLevelHumanoidControllerFactory.WholeBodyControllerCoreFactory.WholeBodyControllerCore.";
        String feedbackController = wbcc + "WholeBodyFeedbackController.FeedbackControllerToolbox.";
        String booleanVarName = String.format("%sPELVIS_LINKisPointFBControllerEnabled", feedbackController);
        String timestampVarName = "root.LogDataReader.robotTime";
        // NOTE(review): recordingFlag is looked up but never read; the loop below segments on the
        // timestamp instead. Confirm which signal is actually meant to delimit episodes.
        YoBoolean recordingFlag = (YoBoolean) session.getRootRegistry().findVariable(booleanVarName);
        YoDouble timestamp = (YoDouble) session.getRootRegistry().findVariable(timestampVarName);
        LogDataReader reader = session.getLogDataReader();
        long frameCount = reader.getNumberOfEntries();

        boolean currentlyRecording = false;
        int episodeStart = -1;
        for (long frame = 0L; frame < frameCount; ++frame) {
            session.runTick();
            boolean flagValue = (int) timestamp.getValue() % 1000 == 0;
            if (flagValue && !currentlyRecording) {
                // Rising edge: a run starts on this frame.
                episodeStart = (int) frame;
                currentlyRecording = true;
            } else if (!flagValue && currentlyRecording) {
                // Falling edge: the run ended on the previous frame.
                System.out.println((int) frame - episodeStart);
                currentlyRecording = false;
                episodeStart = -1;
            }
        }
        if (currentlyRecording) {
            // The log ended while a run was still open.
            System.out.println((int) frameCount - episodeStart);
        }
    }

    /**
     * Removes the episode at {@code index}: deletes its parquet and MP4 files, renames the files
     * of all later episodes down by one, drops and re-indexes the matching JSONL lines, reloads
     * the in-memory state from the corrected files and finally regenerates all metadata.
     *
     * <p>If the episode list is empty, or any file of the episode is missing, a warning is logged
     * and nothing is changed.
     *
     * @param index zero-based index of the episode to remove
     * @throws IndexOutOfBoundsException if {@code index} is not a valid episode index
     * @throws IOException               if a file cannot be renamed or a JSONL file cannot be rewritten
     */
    public void removeEpisode(int index) throws IOException {
        if (episodes.isEmpty()) {
            LogTools.warn("No episodes to remove.");
            return;
        }
        int episodeCount = episodes.size();
        if (index < 0 || index >= episodeCount) {
            throw new IndexOutOfBoundsException("Episode index " + index + " is out of range [0, " + episodeCount + ")");
        }
        String episodeName = episodes.get(index).getEpisodeName();

        // Check every file up front so that a missing one cannot leave the dataset half-modified.
        Path parquetToDelete = dataChunk0Path.resolve(episodeName + ".parquet");
        if (!Files.exists(parquetToDelete)) {
            LogTools.warn("Parquet does not exist: " + parquetToDelete);
            return;
        }
        for (RobotSide side : RobotSide.values) {
            Path mp4Path = zedVideoDirs.get(side).resolve(episodeName + ".mp4");
            if (!Files.exists(mp4Path)) {
                LogTools.warn("MP4 does not exist: " + mp4Path);
                return;
            }
        }

        FileTools.deleteQuietly(parquetToDelete);
        changeNumbers(index, episodeCount, dataChunk0Path, ".parquet");
        LogTools.info("Deleted Parquet: " + parquetToDelete);
        for (RobotSide side : RobotSide.values) {
            Path videoDir = zedVideoDirs.get(side);
            Path mp4Path = videoDir.resolve(episodeName + ".mp4");
            FileTools.deleteQuietly(mp4Path);
            changeNumbers(index, episodeCount, videoDir, ".mp4");
            LogTools.info("Deleted MP4: " + mp4Path);
        }

        removeLineFromJsonl(episodesJsonlPath, index);
        removeLineFromJsonl(episodeStatsJsonlPath, index);
        shiftEpisodeIndicesInJsonl(episodesJsonlPath, index);
        shiftEpisodeIndicesInJsonl(episodeStatsJsonlPath, index);

        // Rebuild the in-memory episodes from the corrected JSONL. Merely dropping one list entry
        // would leave the remaining objects paired with the wrong task, length and start offset.
        // NOTE(review): the contents of the renamed parquet files are not rewritten here.
        loadData();
        LogTools.info("Removed episode: " + episodeName);
        regenerateAndRewriteMetadata();
    }

    /**
     * Renames the files of episodes {@code index + 1 .. finalNumber - 1} in {@code fileSpot} so that
     * each takes the file name of the episode before it.
     *
     * @param index       index of the episode whose file was removed
     * @param finalNumber number of episodes before the removal
     * @param fileSpot    directory containing the files
     * @param fileType    file extension including the dot
     * @throws IOException if a rename fails; renaming stops at the first failure because every
     *                     later rename depends on the previous one having freed its target name
     */
    private void changeNumbers(int index, int finalNumber, Path fileSpot, String fileType) throws IOException {
        for (int i = index + 1; i < finalNumber; ++i) {
            Path source = fileSpot.resolve(episodes.get(i).getEpisodeName() + fileType);
            Path target = fileSpot.resolve(episodes.get(i - 1).getEpisodeName() + fileType);
            try {
                Files.move(source, target);
            } catch (IOException e) {
                LogTools.error("Failed to move file: " + source + " -> " + target + " (" + e + ")");
                throw e;
            }
        }
    }

    /**
     * Rewrites a JSONL file, decrementing every integer {@code episode_index} that is greater than
     * {@code removedIndex}. Blank lines are kept as empty lines.
     *
     * @param jsonlPath    file to rewrite
     * @param removedIndex index of the episode that was removed
     * @throws IOException if the file cannot be read, parsed or written
     */
    private void shiftEpisodeIndicesInJsonl(Path jsonlPath, int removedIndex) throws IOException {
        List<String> allLines = Files.readAllLines(jsonlPath);
        if (allLines.isEmpty()) {
            LogTools.warn("JSONL is empty, nothing to shift: " + jsonlPath);
            return;
        }
        List<String> rewritten = new ArrayList<>(allLines.size());
        for (String rawLine : allLines) {
            String line = rawLine.trim();
            if (line.isEmpty()) {
                rewritten.add(line);
                continue;
            }
            JsonNode root = JSON_MAPPER.readTree(line);
            JsonNode episodeIndexNode = root.get("episode_index");
            if (episodeIndexNode != null && episodeIndexNode.isInt() && episodeIndexNode.intValue() > removedIndex) {
                ((ObjectNode) root).put("episode_index", episodeIndexNode.intValue() - 1);
            }
            rewritten.add(JSON_MAPPER.writeValueAsString(root));
        }
        Files.write(jsonlPath, rewritten, StandardOpenOption.TRUNCATE_EXISTING);
        LogTools.info("Shifted episode_index in JSONL: " + jsonlPath + " (removedIndex=" + removedIndex + ").");
    }

    /**
     * Rewrites a JSONL file without the line at {@code index}. If the file is empty or the index
     * is out of range, a warning is logged and the file is left untouched.
     *
     * @param jsonlPath file to rewrite
     * @param index     zero-based index of the line to drop
     * @throws IOException if the file cannot be read or written
     */
    private void removeLineFromJsonl(Path jsonlPath, int index) throws IOException {
        List<String> allLines = Files.readAllLines(jsonlPath);
        if (allLines.isEmpty()) {
            LogTools.warn("JSONL file is empty (nothing to remove): " + jsonlPath);
            return;
        }
        if (index < 0 || index >= allLines.size()) {
            LogTools.warn("Line index " + index + " is out of range for " + jsonlPath + " (" + allLines.size() + " lines); nothing removed.");
            return;
        }
        // Work on a copy and drop exactly one line.
        List<String> linesToWrite = new ArrayList<>(allLines);
        linesToWrite.remove(index);
        Files.write(jsonlPath, linesToWrite, StandardOpenOption.TRUNCATE_EXISTING);
        LogTools.info("Removed line " + index + " from " + jsonlPath + " (now has " + linesToWrite.size() + " lines).");
    }

    /**
     * Rewrites {@code info.json} and then truncates and regenerates {@code episodes.jsonl},
     * {@code episodes_stats.jsonl} and {@code tasks.jsonl} from the in-memory state.
     */
    public void regenerateAndRewriteMetadata() {
        writeMetaJson();

        FileTools.write(episodesJsonlPath, new byte[0], WriteOption.TRUNCATE, DefaultExceptionHandler.PRINT_MESSAGE);
        for (LeRobotDatasetEpisode episode : episodes) {
            episode.writeEpisodeJsonlLine();
        }

        FileTools.write(episodeStatsJsonlPath, new byte[0], WriteOption.TRUNCATE, DefaultExceptionHandler.PRINT_MESSAGE);
        for (LeRobotDatasetEpisode episode : episodes) {
            LogTools.info("Generating stats for %s...".formatted(episode.getEpisodeName()));
            episode.readDataAndWriteStatisticsJsonlLine();
        }

        FileTools.write(tasksJsonlPath, new byte[0], WriteOption.TRUNCATE, DefaultExceptionHandler.PRINT_MESSAGE);
        for (String taskName : taskNames) {
            writeTaskJsonlLine(taskName);
        }
        LogTools.info("All done regenerating and rewriting metadata.");
    }

    /** Writes the parquet data of every in-memory episode. */
    public void writeParquetData() {
        for (LeRobotDatasetEpisode episode : episodes) {
            episode.writeParquetData();
        }
    }

    /**
     * Recomputes {@link #getTotalFrames()} from the in-memory episodes and writes {@code info.json}.
     * The frame rate is taken from the first episode, or {@code 1.0} when there are no episodes.
     */
    public void writeMetaJson() {
        JSONFileTools.save(infoJsonPath, rootNode -> {
            totalFrames = 0L;
            for (LeRobotDatasetEpisode episode : episodes) {
                totalFrames += (long) episode.getLength();
            }
            float fps = episodes.isEmpty() ? 1.0f : episodes.get(0).getFps();

            rootNode.put("codebase_version", "v2.1");
            rootNode.put("robot_type", "nadia");
            rootNode.put("total_episodes", episodes.size());
            rootNode.put("total_frames", totalFrames);
            rootNode.put("total_tasks", taskNames.size());
            rootNode.put("total_videos", 2 * episodes.size());
            rootNode.put("total_chunks", 1);
            rootNode.put("chunks_size", 1000);
            rootNode.put("fps", fps);
            ObjectNode splits = rootNode.putObject("splits");
            splits.put("train", "0:%d".formatted(episodes.size()));
            rootNode.put("data_path", "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet");
            rootNode.put("video_path", "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4");

            ObjectNode features = rootNode.putObject("features");
            for (RobotSide side : RobotSide.values) {
                ObjectNode cam = features.putObject("observation.images.cam_zed_%s".formatted(side.getLowerCaseName()));
                cam.put("dtype", "video");
                cam.putArray("shape").add(480).add(640).add(3);
                cam.putArray("names").add("height").add("width").add("channel");
                cam.putObject("video_info").put("video.fps", fps).put("video.codec", "mpeg4").put("video.pix_fmt", "yuv420p").put("video.is_depth_map", false).put("has_audio", false);
            }
            putGripperPoseFeature(features, "observation.state");
            putGripperPoseFeature(features, "action");
            putScalarFeature(features, "episode_index", "int64");
            putScalarFeature(features, "frame_index", "int64");
            putScalarFeature(features, "timestamp", "float32");
            putScalarFeature(features, "next.done", "bool");
            putScalarFeature(features, "index", "int64");
            putScalarFeature(features, "task_index", "int64");
        });
    }

    /** Adds a 14-element float32 feature whose motor names are the left and right gripper poses. */
    private static void putGripperPoseFeature(ObjectNode features, String featureName) {
        ObjectNode feature = features.putObject(featureName);
        feature.put("dtype", "float32");
        feature.putArray("shape").add(14);
        ArrayNode motors = feature.putObject("names").putArray("motors");
        for (String side : GRIPPER_SIDES) {
            for (String component : GRIPPER_POSE_COMPONENTS) {
                motors.add(side + "_gripper_" + component);
            }
        }
    }

    /** Adds a shape-[1] feature of the given dtype with a JSON {@code null} for its names. */
    private static void putScalarFeature(ObjectNode features, String featureName, String dtype) {
        ObjectNode feature = features.putObject(featureName);
        feature.put("dtype", dtype);
        feature.putArray("shape").add(1);
        feature.putNull("names");
    }

    /** Adds the task to the in-memory list and appends it to {@code tasks.jsonl} if not yet known. */
    private void registerTaskIfNew(String taskName) {
        if (!taskNames.contains(taskName)) {
            taskNames.add(taskName);
            writeTaskJsonlLine(taskName);
        }
    }

    /**
     * Appends one line for {@code taskName} to {@code tasks.jsonl}. The written {@code task_index}
     * is the task's position in the in-memory task list, so the method is correct both for a
     * freshly added task and when every task is rewritten in turn.
     */
    private void writeTaskJsonlLine(String taskName) {
        int taskIndex = taskNames.indexOf(taskName);
        LeRobotDatasetTools.appendLine(tasksJsonlPath, JSONFileTools.getAsSingleLine(node -> {
            node.put("task_index", taskIndex);
            node.put("task", taskName);
        }));
    }

    /** @return the file name of the dataset directory */
    public String getName() {
        return name;
    }

    /** @return the dataset root directory */
    public Path getDirectory() {
        return directory;
    }

    /** @return the {@code data} directory */
    public Path getDataPath() {
        return dataPath;
    }

    /** @return the {@code meta} directory */
    public Path getMetaPath() {
        return metaPath;
    }

    /** @return the {@code videos} directory */
    public Path getVideosPath() {
        return videosPath;
    }

    /** @return the {@code data/chunk-000} directory */
    public Path getDataChunk0Path() {
        return dataChunk0Path;
    }

    /** @return the per-side video directories (the live internal map, not a copy) */
    public SideDependentList<Path> getZedVideoDirs() {
        return zedVideoDirs;
    }

    /** @return the known task names (the live internal list, not a copy) */
    public List<String> getTaskNames() {
        return taskNames;
    }

    /** @return the known episodes (the live internal list, not a copy) */
    public List<LeRobotDatasetEpisode> getEpisodes() {
        return episodes;
    }

    /** @return the frame total as of the last load or {@code info.json} write */
    public long getTotalFrames() {
        return totalFrames;
    }

    /** @param usePerfectTimestamps value handed to the episode generator by {@link #addEpisode} */
    public void setUsePerfectTimestamps(boolean usePerfectTimestamps) {
        this.usePerfectTimestamps = usePerfectTimestamps;
    }

    /** @return the value handed to the episode generator by {@link #addEpisode} */
    public boolean getUsePerfectTimestamps() {
        return usePerfectTimestamps;
    }
}

