package net.imglib2.img.display.imagej;

import ij.ImagePlus;
import ij.VirtualStack;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imagej.axis.AxisType;
import net.imglib2.Interval;
import net.imglib2.img.basictypeaccess.PlanarAccess;
import net.imglib2.img.basictypeaccess.array.ArrayDataAccess;
import net.imglib2.img.basictypeaccess.array.ByteArray;
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.img.basictypeaccess.array.IntArray;
import net.imglib2.img.basictypeaccess.array.ShortArray;
import net.imglib2.img.planar.PlanarImg;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.ARGBType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.IntervalIndexer;

/* loaded from: input_file:net/imglib2/img/display/imagej/PlanarImgToVirtualStack.class */
/**
 * Exposes the planes of an ImgLib2 {@link PlanarImg} as an ImageJ {@link VirtualStack}.
 * <p>
 * Each ImageJ slice corresponds to one plane of the wrapped image.
 * {@link #getPixelsZeroBasedIndex(int)} returns the plane's storage array.
 * {@link #setPixelsZeroBasedIndex(int, Object)} wraps the given pixel array in a new
 * access object and installs it as the plane.
 */
public class PlanarImgToVirtualStack extends AbstractVirtualStack {

    /** The wrapped image; planes are looked up by the index produced by {@link #indexer}. */
    private final PlanarAccess<? extends ArrayDataAccess<?>> img;

    /** Maps a zero-based ImageJ stack index to the corresponding plane index of {@link #img}. */
    private final IntUnaryOperator indexer;

    /**
     * Axes that may appear in a wrappable image, listed in ImageJ's preferred order
     * (X, Y, channel, Z, time). The list position doubles as the preferred position.
     */
    private static final List<AxisType> ALLOWED_AXES = Arrays.asList(Axes.X, Axes.Y, Axes.CHANNEL, Axes.Z, Axes.TIME);

    /**
     * Tells whether {@link #wrap(ImgPlus)} can wrap the given image.
     * <p>
     * The image must satisfy three conditions after {@code ImgPlusViews.fixAxes}:
     * <ul>
     * <li>it is backed by a {@link PlanarImg};</li>
     * <li>its axes start with X, Y and are otherwise a unique subset of channel, Z and time;</li>
     * <li>its pixel type is accepted by {@code ImageProcessorUtils.isSupported}.</li>
     * </ul>
     *
     * @param imgPlus the image to test
     * @return {@code true} if the image can be wrapped
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static boolean isSupported(ImgPlus<?> imgPlus) {
        ImgPlus<?> fixed = ImgPlusViews.fixAxes(imgPlus);
        return fixed.getImg() instanceof PlanarImg
                && checkAxisOrder(getAxes(fixed))
                && ImageProcessorUtils.isSupported((NativeType) fixed.randomAccess().get());
    }

    /**
     * Wraps a planar {@link ImgPlus} as an {@link ImagePlus} whose stack is backed by the image planes.
     * <p>
     * The channel, Z and time sizes of the result come from the image axes; missing axes count as size 1.
     * Calibration is copied to the result.
     *
     * @param imgPlus the image to wrap
     * @return an ImagePlus viewing the image's planes
     * @throws IllegalArgumentException if the image is not a {@link PlanarImg} or its axis order is unsupported
     */
    public static ImagePlus wrap(ImgPlus<?> imgPlus) {
        ImgPlus<?> fixed = ImgPlusViews.fixAxes(imgPlus);
        if (!(fixed.getImg() instanceof PlanarImg)) {
            throw new IllegalArgumentException("Image must be a PlanarImg.");
        }
        PlanarImg<?, ?> planarImg = (PlanarImg<?, ?>) fixed.getImg();
        // getIndexer also validates the axis order and throws if it is unsupported.
        IntUnaryOperator stackToPlane = getIndexer(fixed);
        ImagePlus imagePlus = new ImagePlus(fixed.getName(), new PlanarImgToVirtualStack(planarImg, stackToPlane));
        imagePlus.setDimensions(dimension(fixed, Axes.CHANNEL), dimension(fixed, Axes.Z), dimension(fixed, Axes.TIME));
        CalibrationUtils.copyCalibrationToImagePlus(fixed, imagePlus);
        return imagePlus;
    }

    /**
     * Returns the size of the given axis, or 1 if the image has no such axis.
     */
    private static int dimension(ImgPlus<?> imgPlus, AxisType axisType) {
        int d = imgPlus.dimensionIndex(axisType);
        return d < 0 ? 1 : (int) imgPlus.dimension(d);
    }

    /**
     * Wraps a {@link PlanarImg} as a virtual stack.
     * <p>
     * Stack index {@code i} maps directly to plane {@code i}, with no axis reordering.
     *
     * @param planarImg the image to wrap
     * @return a virtual stack backed by the image planes
     */
    public static VirtualStack wrap(PlanarImg<?, ?> planarImg) {
        return new PlanarImgToVirtualStack(planarImg, IntUnaryOperator.identity());
    }

    /**
     * Creates the stack.
     * <p>
     * Width and height are taken from dimensions 0 and 1. The stack size is the product of all
     * remaining dimensions. The bit depth is derived from the pixel type.
     */
    private PlanarImgToVirtualStack(PlanarImg<?, ?> planarImg, IntUnaryOperator indexer) {
        super((int) planarImg.dimension(0), (int) planarImg.dimension(1), initSize(planarImg), getBitDepth(planarImg.randomAccess().get()));
        this.img = planarImg;
        this.indexer = indexer;
    }

    /** Product of the sizes of all dimensions from index 2 onward; 1 for a 2D interval. */
    private static int initSize(Interval interval) {
        return IntStream.range(2, interval.numDimensions())
                .map(d -> (int) interval.dimension(d))
                .reduce(1, (a, b) -> a * b);
    }

    /** Returns the storage array of the plane that backs the given zero-based stack index. */
    @Override // net.imglib2.img.display.imagej.AbstractVirtualStack
    protected Object getPixelsZeroBasedIndex(int index) {
        return img.getPlane(indexer.applyAsInt(index)).getCurrentStorageArray();
    }

    /**
     * Replaces the plane behind the given zero-based stack index.
     * <p>
     * The given pixel array is wrapped in a new access object and installed as the plane.
     */
    @Override // net.imglib2.img.display.imagej.AbstractVirtualStack
    protected void setPixelsZeroBasedIndex(int index, Object pixels) {
        setPlaneCastType(img, indexer.applyAsInt(index), wrapPixelsToAccess(pixels));
    }

    /**
     * Installs {@code access} as plane {@code index}. The cast to {@code A} is unchecked.
     * Callers must pass an access whose array type matches the image's planes.
     * A mismatch is not detected here.
     */
    @SuppressWarnings("unchecked")
    private static <A extends ArrayDataAccess<?>> void setPlaneCastType(PlanarAccess<A> planarAccess, int index, ArrayDataAccess<?> access) {
        planarAccess.setPlane(index, (A) access);
    }

    /**
     * Wraps an ImageJ pixel array (byte[], short[], int[] or float[]) in the matching ImgLib2 array access.
     *
     * @throws UnsupportedOperationException for any other (or null) pixel object
     */
    private static ArrayDataAccess<?> wrapPixelsToAccess(Object pixels) {
        if (pixels instanceof byte[]) {
            return new ByteArray((byte[]) pixels);
        }
        if (pixels instanceof short[]) {
            return new ShortArray((short[]) pixels);
        }
        if (pixels instanceof int[]) {
            return new IntArray((int[]) pixels);
        }
        if (pixels instanceof float[]) {
            return new FloatArray((float[]) pixels);
        }
        throw new UnsupportedOperationException("Unsupported pixel array type: "
                + (pixels == null ? "null" : pixels.getClass().getName()));
    }

    /**
     * ImageJ bit depth for the supported pixel types.
     * <p>
     * The mapping is 8 (unsigned byte), 16 (unsigned short), 24 (ARGB) and 32 (float).
     *
     * @throws IllegalArgumentException for any other type
     */
    private static int getBitDepth(Type<?> type) {
        if (type instanceof UnsignedByteType) {
            return 8;
        }
        if (type instanceof UnsignedShortType) {
            return 16;
        }
        if (type instanceof ARGBType) {
            return 24;
        }
        if (type instanceof FloatType) {
            return 32;
        }
        throw new IllegalArgumentException("unsupported type");
    }

    /**
     * Builds the mapping from ImageJ stack index to plane index of the image.
     * <p>
     * If the axes are already in preferred order (X, Y, C, Z, T), the mapping is the identity.
     * Otherwise each stack index is decomposed into a (channel, z, time) position, with channel
     * varying fastest. That position is then re-linearized using the image's own axis order.
     *
     * @throws IllegalArgumentException if the axis order is unsupported
     */
    private static IntUnaryOperator getIndexer(ImgPlus<?> imgPlus) {
        List<AxisType> axes = getAxes(imgPlus);
        if (!checkAxisOrder(axes)) {
            throw new IllegalArgumentException("Unsupported axis order, first axis must be X, second axis must be Y, and then optionally, arbitrary ordered: channel, Z and time.");
        }
        if (inPreferredOrder(axes)) {
            return IntUnaryOperator.identity();
        }
        int[] cztSizes = {dimension(imgPlus, Axes.CHANNEL), dimension(imgPlus, Axes.Z), dimension(imgPlus, Axes.TIME)};
        int channelSkip = getSkip(imgPlus, Axes.CHANNEL);
        int zSkip = getSkip(imgPlus, Axes.Z);
        int timeSkip = getSkip(imgPlus, Axes.TIME);
        return stackIndex -> {
            // Allocated per call so the returned operator carries no shared mutable state.
            int[] czt = new int[3];
            IntervalIndexer.indexToPosition(stackIndex, cztSizes, czt);
            return channelSkip * czt[0] + zSkip * czt[1] + timeSkip * czt[2];
        };
    }

    /** Axis types of the image, in dimension order. */
    private static List<AxisType> getAxes(ImgPlus<?> imgPlus) {
        return IntStream.range(0, imgPlus.numDimensions())
                .mapToObj(d -> imgPlus.axis(d).type())
                .collect(Collectors.toList());
    }

    /**
     * Checks that the axes start with X and Y and contain no duplicates.
     * Every axis must be one of {@link #ALLOWED_AXES}.
     */
    private static boolean checkAxisOrder(List<AxisType> axes) {
        return axes.size() >= 2
                && axes.size() <= ALLOWED_AXES.size()
                && testUnique(axes)
                && axes.get(0) == Axes.X
                && axes.get(1) == Axes.Y
                && ALLOWED_AXES.containsAll(axes);
    }

    /** True if the preferred positions of the axes strictly increase along the list. */
    private static boolean inPreferredOrder(List<AxisType> axes) {
        for (int i = 0; i < axes.size() - 1; i++) {
            if (preferredPosition(axes.get(i)) >= preferredPosition(axes.get(i + 1))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Position of the axis in ImageJ's preferred order (X=0, Y=1, C=2, Z=3, T=4).
     *
     * @throws IllegalArgumentException if the axis is not one of {@link #ALLOWED_AXES}
     */
    private static int preferredPosition(AxisType axisType) {
        int position = ALLOWED_AXES.indexOf(axisType);
        if (position < 0) {
            throw new IllegalArgumentException("unknown axis");
        }
        return position;
    }

    /**
     * Number of planes between consecutive positions along the given axis.
     * <p>
     * This is the product of the sizes of the non-XY dimensions that precede the axis.
     * For an absent axis the result is 1; that axis's position is always 0, so the value is
     * never used.
     */
    private static int getSkip(ImgPlus<?> imgPlus, AxisType axisType) {
        return IntStream.range(2, imgPlus.dimensionIndex(axisType))
                .map(d -> (int) imgPlus.dimension(d))
                .reduce(1, (a, b) -> a * b);
    }

    /** True if the list contains no duplicate elements. */
    private static <T> boolean testUnique(List<T> list) {
        return new HashSet<>(list).size() == list.size();
    }
}
