/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * An {@link ObjectDetectionTranslator} that turns the raw output of a YOLO-style detection model
 * into {@link DetectedObjects}.
 *
 * <p>The model output is read as three arrays, in order: class indices, class probabilities, and
 * bounding boxes whose columns 0..3 are treated as {@code xMin, yMin, xMax, yMax}. NOTE(review):
 * the box coordinates are presumably in the same units as {@code imageWidth}/{@code imageHeight}
 * (pixels) — confirm against the model export.
 */
public class YoloTranslator extends ObjectDetectionTranslator {

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the translator configuration
     */
    public YoloTranslator(Builder builder) {
        super(builder);
    }

    /**
     * Converts the model output into detected objects.
     *
     * <p>An entry is dropped when its class index is negative, or when its probability is below
     * the configured threshold. Each kept box is clipped to {@code [0, imageWidth]} x
     * {@code [0, imageHeight]}. It is then normalized by those extents and returned as a
     * {@link Rectangle} of {@code (x, y, width, height)}.
     *
     * @param ctx the translator context (not read here)
     * @param list the model output: class indices, probabilities, and boxes (assumed 2-D with at
     *     least four columns, one row per candidate)
     * @return the detections that passed the filters, in output order
     * @throws Exception if reading the output arrays fails
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception {
        int[] classIndices = list.get(0).toType(DataType.INT32, true).flatten().toIntArray();
        double[] probs = list.get(1).toType(DataType.FLOAT64, true).flatten().toDoubleArray();
        NDArray boundingBoxes = list.get(2);
        int detected = probs.length;

        NDArray xMin = normalizedColumn(boundingBoxes, 0, imageWidth);
        NDArray yMin = normalizedColumn(boundingBoxes, 1, imageHeight);
        NDArray xMax = normalizedColumn(boundingBoxes, 2, imageWidth);
        NDArray yMax = normalizedColumn(boundingBoxes, 3, imageHeight);

        float[] boxX = xMin.toFloatArray();
        float[] boxY = yMin.toFloatArray();
        float[] boxWidth = xMax.sub(xMin).toFloatArray();
        float[] boxHeight = yMax.sub(yMin).toFloatArray();

        List<String> retClasses = new ArrayList<>(detected);
        List<Double> retProbs = new ArrayList<>(detected);
        List<BoundingBox> retBB = new ArrayList<>(detected);
        for (int i = 0; i < detected; ++i) {
            // A negative class index marks an entry with no valid class; skip it together with
            // low-confidence entries.
            if (classIndices[i] < 0 || probs[i] < threshold) {
                continue;
            }
            retClasses.add(classes.get(classIndices[i]));
            retProbs.add(probs[i]);
            retBB.add(new Rectangle(boxX[i], boxY[i], boxWidth[i], boxHeight[i]));
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    /**
     * Extracts one coordinate column of {@code boxes}, clips it to {@code [0, extent]} and
     * divides it by {@code extent} so that in-range values map to {@code [0, 1]}.
     *
     * @param boxes the box array, indexed as {@code [row, column]}
     * @param column the coordinate column to extract
     * @param extent the image extent along that coordinate's axis
     * @return the clipped, normalized column
     */
    private static NDArray normalizedColumn(NDArray boxes, int column, double extent) {
        return boxes.get(":, " + column).clip(0, extent).div(extent);
    }

    /**
     * Creates a new builder for {@link YoloTranslator}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The builder for {@link YoloTranslator}. */
    public static class Builder extends ObjectDetectionTranslator.BaseBuilder<Builder> {

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Validates the configuration and builds the translator.
         *
         * @return a new {@link YoloTranslator}
         */
        public YoloTranslator build() {
            validate();
            return new YoloTranslator(this);
        }
    }
}

