/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.avatar.logProcessor.leRobot;

import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import org.bytedeco.opencv.opencv_core.Mat;
import us.ihmc.avatar.logProcessor.leRobot.LeRobotEpisodeRecord;
import us.ihmc.avatar.logProcessor.leRobot.LeRobotFloatStatisticsCalculator;
import us.ihmc.avatar.logProcessor.leRobot.LeRobotIntegerStatisticsCalculator;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.robotSide.RobotSide;
import us.ihmc.robotics.robotSide.SideDependentList;

/**
 * Accumulates the statistics of a single LeRobot dataset episode and serializes them to JSON.
 *
 * <p>Two kinds of input are accumulated:
 * <ul>
 *   <li>camera frames, per {@link RobotSide}, via {@link #submitFrame(RobotSide, Mat)}: exact integer
 *       per-channel sums and sums of squares of the unsigned 8-bit pixel values;</li>
 *   <li>tabular records via {@link #processParquetRecord(LeRobotEpisodeRecord)}: per-element
 *       state/action statistics plus scalar statistics for the index, timestamp and done columns.</li>
 * </ul>
 *
 * <p>Call {@link #calculate()} once all input has been submitted and before
 * {@link #writeJson(ObjectNode, SideDependentList)}. This class performs no synchronization.
 */
public class LeRobotDatasetEpisodeStatistics {
    /**
     * Number of pixels accumulated per side, indexed by {@link RobotSide#ordinal()}. This is a long
     * because an episode easily exceeds {@link Integer#MAX_VALUE} pixels (about 1036 frames at 1920x1080).
     */
    private final long[] sizes = new long[RobotSide.values().length];
    private final SideDependentList<RGBL> sums = new SideDependentList<>(new RGBL(), new RGBL());
    private final SideDependentList<RGBL> sumSquares = new SideDependentList<>(new RGBL(), new RGBL());
    private final SideDependentList<RGB> means = new SideDependentList<>(new RGB(), new RGB());
    private final SideDependentList<RGB> stds = new SideDependentList<>(new RGB(), new RGB());
    /** Number of tabular records processed so far; written as "count" for every tabular feature. */
    private int length = 0;
    /** One calculator per state/action element, created lazily as records arrive. */
    private final List<LeRobotFloatStatisticsCalculator> stateStats = new ArrayList<>();
    private final List<LeRobotFloatStatisticsCalculator> actionStats = new ArrayList<>();
    private final LeRobotIntegerStatisticsCalculator episodeIndexStats = new LeRobotIntegerStatisticsCalculator();
    private final LeRobotIntegerStatisticsCalculator frameIndexStats = new LeRobotIntegerStatisticsCalculator();
    private final LeRobotFloatStatisticsCalculator timestampStats = new LeRobotFloatStatisticsCalculator();
    private final LeRobotIntegerStatisticsCalculator nextDoneStats = new LeRobotIntegerStatisticsCalculator();
    private final LeRobotIntegerStatisticsCalculator indexStats = new LeRobotIntegerStatisticsCalculator();
    private final LeRobotIntegerStatisticsCalculator taskIndexStats = new LeRobotIntegerStatisticsCalculator();

    /**
     * Adds the pixels of one camera frame to the running per-channel sums for {@code side}.
     *
     * @param side   which camera the frame came from
     * @param bgrMat frame whose first three interleaved channels are B, G, R; every channel value is
     *               read as one unsigned byte (assumes 8-bit depth, which is not checked here)
     * @throws IllegalArgumentException if a non-empty frame has fewer than 3 channels or is not
     *                                  stored contiguously in memory
     */
    public void submitFrame(RobotSide side, Mat bgrMat) {
        int height = bgrMat.rows();
        int width = bgrMat.cols();
        int channels = bgrMat.channels();
        int totalPixels = height * width;
        if (totalPixels == 0) {
            return; // nothing to accumulate
        }
        // Validate before touching any state so a rejected frame leaves the statistics intact.
        if (channels < 3) {
            throw new IllegalArgumentException("Expected at least 3 (BGR) channels but got " + channels);
        }
        // The pixel buffer is copied in a single call below, which is only valid when rows are
        // stored back-to-back (not the case for e.g. ROI views into a larger image).
        if (!bgrMat.isContinuous()) {
            throw new IllegalArgumentException("Expected a continuous Mat; non-continuous images are not supported");
        }

        sizes[side.ordinal()] += totalPixels;
        RGBL sum = sums.get(side);
        RGBL sumSq = sumSquares.get(side);

        byte[] data = new byte[totalPixels * channels];
        bgrMat.data().get(data);
        for (int offset = 0; offset < data.length; offset += channels) {
            // & 0xFF reinterprets Java's signed byte as an unsigned value in [0, 255]
            long b = data[offset] & 0xFF;
            long g = data[offset + 1] & 0xFF;
            long r = data[offset + 2] & 0xFF;
            sum.b += b;
            sum.g += g;
            sum.r += r;
            sumSq.b += b * b;
            sumSq.g += g * g;
            sumSq.r += r * r;
        }
    }

    /**
     * Feeds one tabular record into the state, action and scalar-column statistics.
     *
     * <p>Per-element calculators are created on demand. A record with more state/action elements
     * than earlier ones therefore does not fail; the extra elements simply have fewer samples than
     * the reported count.
     *
     * @param dataFrame the record to accumulate
     */
    public void processParquetRecord(LeRobotEpisodeRecord dataFrame) {
        for (int i = 0; i < dataFrame.state().size(); ++i) {
            while (stateStats.size() <= i) {
                stateStats.add(new LeRobotFloatStatisticsCalculator());
            }
            stateStats.get(i).addValue(dataFrame.state().get(i).floatValue());
        }
        for (int i = 0; i < dataFrame.action().size(); ++i) {
            while (actionStats.size() <= i) {
                actionStats.add(new LeRobotFloatStatisticsCalculator());
            }
            actionStats.get(i).addValue(dataFrame.action().get(i).floatValue());
        }
        episodeIndexStats.addValue(dataFrame.episodeIndex());
        frameIndexStats.addValue(dataFrame.frameIndex());
        timestampStats.addValue(dataFrame.timestamp());
        nextDoneStats.addValue(dataFrame.nextDone() ? 1L : 0L);
        indexStats.addValue(dataFrame.index());
        taskIndexStats.addValue(dataFrame.taskIndex());
        ++length;
    }

    /**
     * Finalizes all accumulated statistics.
     *
     * <p>For images, stores the per-channel mean and population standard deviation, scaled from
     * [0, 255] to [0, 1], and logs them. A side with no submitted pixels keeps mean/std at zero
     * (with a warning) instead of producing NaN. All tabular calculators are then finalized.
     */
    public void calculate() {
        for (RobotSide side : RobotSide.values) {
            long totalPixels = sizes[side.ordinal()];
            RGB mean = means.get(side);
            RGB std = stds.get(side);
            if (totalPixels == 0) {
                LogTools.warn("No image pixels submitted for %s; leaving mean/std at zero".formatted(side));
                continue;
            }
            RGBL sum = sums.get(side);
            RGBL sumSq = sumSquares.get(side);
            double count = totalPixels;

            mean.r = (float) (sum.r / count / 255.0);
            mean.g = (float) (sum.g / count / 255.0);
            mean.b = (float) (sum.b / count / 255.0);
            std.r = normalizedStd(sum.r, sumSq.r, count);
            std.g = normalizedStd(sum.g, sumSq.g, count);
            std.b = normalizedStd(sum.b, sumSq.b, count);

            LogTools.info("%s Mean RGB: R=%.3f G=%.3f B=%.3f".formatted(side, mean.r, mean.g, mean.b));
            LogTools.info("%s StdDev RGB: R=%.3f G=%.3f B=%.3f".formatted(side, std.r, std.g, std.b));
        }
        for (LeRobotFloatStatisticsCalculator calculator : stateStats) {
            calculator.calculate();
        }
        for (LeRobotFloatStatisticsCalculator calculator : actionStats) {
            calculator.calculate();
        }
        episodeIndexStats.calculate();
        frameIndexStats.calculate();
        timestampStats.calculate();
        nextDoneStats.calculate();
        indexStats.calculate();
        taskIndexStats.calculate();
    }

    /**
     * Population standard deviation of samples in [0, 255], scaled to [0, 1], computed as
     * sqrt(E[x^2] - E[x]^2). A slightly negative variance caused by rounding is clamped to 0.
     */
    private static float normalizedStd(long sum, long sumOfSquares, double count) {
        double mean = sum / count;
        double variance = sumOfSquares / count - mean * mean;
        return (float) (Math.sqrt(Math.max(0.0, variance)) / 255.0);
    }

    /**
     * Writes the computed statistics into {@code stats}.
     *
     * <p>Keys are added in this order:
     * <ol>
     *   <li>one object per camera side, keyed by the file name of that side's video directory;</li>
     *   <li>{@code observation.state} and {@code action};</li>
     *   <li>{@code episode_index}, {@code frame_index}, {@code timestamp}, {@code next.done},
     *       {@code index} and {@code task_index}.</li>
     * </ol>
     * Each object holds {@code min}, {@code max}, {@code mean}, {@code std} and {@code count} arrays.
     *
     * @param stats        JSON object to populate
     * @param zedVideoDirs per-side video directory; its file name becomes the image feature key
     */
    public void writeJson(ObjectNode stats, SideDependentList<Path> zedVideoDirs) {
        for (RobotSide side : RobotSide.values) {
            RGB mean = means.get(side);
            RGB std = stds.get(side);
            ObjectNode imageNode = stats.putObject(zedVideoDirs.get(side).getFileName().toString());
            // Image min/max are the fixed normalized pixel range rather than observed values.
            putChannelValues(imageNode.putArray("min"), 0.0f, 0.0f, 0.0f);
            putChannelValues(imageNode.putArray("max"), 1.0f, 1.0f, 1.0f);
            putChannelValues(imageNode.putArray("mean"), mean.r, mean.g, mean.b);
            putChannelValues(imageNode.putArray("std"), std.r, std.g, std.b);
            imageNode.putArray("count").add(sizes[side.ordinal()]);
        }

        putVectorStats(stats, "observation.state", stateStats);
        putVectorStats(stats, "action", actionStats);
        putIntegerStats(stats, "episode_index", episodeIndexStats);
        putIntegerStats(stats, "frame_index", frameIndexStats);
        putFloatStats(stats, "timestamp", timestampStats);

        // next.done was accumulated as 0/1 from a boolean, so min/max are written back as booleans.
        ObjectNode nextDone = stats.putObject("next.done");
        nextDone.putArray("min").add(nextDoneStats.getMin() == 1L);
        nextDone.putArray("max").add(nextDoneStats.getMax() == 1L);
        nextDone.putArray("mean").add(nextDoneStats.getMean());
        nextDone.putArray("std").add(nextDoneStats.getStddev());
        nextDone.putArray("count").add(length);

        putIntegerStats(stats, "index", indexStats);
        putIntegerStats(stats, "task_index", taskIndexStats);
    }

    /** Appends r, g, b to {@code array} as {@code [[[r]], [[g]], [[b]]]}, i.e. a (3, 1, 1) nesting. */
    private static void putChannelValues(ArrayNode array, float r, float g, float b) {
        array.addArray().addArray().add(r);
        array.addArray().addArray().add(g);
        array.addArray().addArray().add(b);
    }

    /** Writes element-wise min/max/mean/std arrays for a vector feature, plus the record count. */
    private void putVectorStats(ObjectNode stats, String key, List<LeRobotFloatStatisticsCalculator> calculators) {
        ObjectNode node = stats.putObject(key);
        ArrayNode min = node.putArray("min");
        ArrayNode max = node.putArray("max");
        ArrayNode mean = node.putArray("mean");
        ArrayNode std = node.putArray("std");
        for (LeRobotFloatStatisticsCalculator calculator : calculators) {
            min.add(calculator.getMin());
            max.add(calculator.getMax());
            mean.add(calculator.getMean());
            std.add(calculator.getStddev());
        }
        node.putArray("count").add(length);
    }

    /** Writes single-element min/max/mean/std/count arrays for an integer-valued scalar feature. */
    private void putIntegerStats(ObjectNode stats, String key, LeRobotIntegerStatisticsCalculator calculator) {
        ObjectNode node = stats.putObject(key);
        node.putArray("min").add(calculator.getMin());
        node.putArray("max").add(calculator.getMax());
        node.putArray("mean").add(calculator.getMean());
        node.putArray("std").add(calculator.getStddev());
        node.putArray("count").add(length);
    }

    /** Writes single-element min/max/mean/std/count arrays for a floating-point scalar feature. */
    private void putFloatStats(ObjectNode stats, String key, LeRobotFloatStatisticsCalculator calculator) {
        ObjectNode node = stats.putObject(key);
        node.putArray("min").add(calculator.getMin());
        node.putArray("max").add(calculator.getMax());
        node.putArray("mean").add(calculator.getMean());
        node.putArray("std").add(calculator.getStddev());
        node.putArray("count").add(length);
    }

    /** Exact per-channel integer accumulators (sums or sums of squares of 0-255 values). */
    private static final class RGBL {
        long r;
        long g;
        long b;
    }

    /** Per-channel results normalized to [0, 1]. */
    private static final class RGB {
        float r;
        float g;
        float b;
    }
}

