/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;

/**
 * Factory for an AlexNet-style network assembled from DJL blocks.
 *
 * <p>The produced block is one flat {@link SequentialBlock} containing, in order: five 2-D
 * convolutions (each followed by ReLU, with max pooling after the first, second and fifth), a
 * batch-flatten step, two dense layers (each followed by ReLU and dropout) and a final dense
 * output layer.
 */
public final class AlexNet {
    /** Unit count of the final dense layer unless overridden via {@link Builder#optOutSize(long)}. */
    private static final long DEFAULT_OUT_SIZE = 10L;

    private AlexNet() {
    }

    /**
     * Builds the network described by the given builder.
     *
     * <p>All layers are added directly to a single {@link SequentialBlock}; the helpers below only
     * create individual child blocks, so the order and nesting of children is unchanged.
     *
     * @param builder configuration holding the per-layer channel/unit counts, the dropout rate and
     *     the output size
     * @return the assembled {@link Block}
     * @throws NullPointerException if {@code builder} is {@code null}
     */
    public static Block alexNet(Builder builder) {
        int[] channels = builder.numChannels;
        float dropOutRate = builder.dropOutRate;

        return new SequentialBlock()
                // Stage 1: 11x11 convolution with stride 4 (the only strided, unpadded convolution).
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(11, 11))
                        .optStride(new Shape(4, 4))
                        .setFilters(channels[0])
                        .build())
                .add(Activation::relu)
                .add(maxPool())
                // Stage 2: 5x5 convolution, padding 2.
                .add(paddedConv(5, 2, channels[1]))
                .add(Activation::relu)
                .add(maxPool())
                // Stages 3-5: three 3x3 convolutions with padding 1, pooled once at the end.
                .add(paddedConv(3, 1, channels[2]))
                .add(Activation::relu)
                .add(paddedConv(3, 1, channels[3]))
                .add(Activation::relu)
                .add(paddedConv(3, 1, channels[4]))
                .add(Activation::relu)
                .add(maxPool())
                // Dense head: flatten, two dense+ReLU+dropout groups, then the output layer.
                .add(Blocks.batchFlattenBlock())
                .add(dense(channels[5]))
                .add(Activation::relu)
                .add(dropout(dropOutRate))
                .add(dense(channels[6]))
                .add(Activation::relu)
                .add(dropout(dropOutRate))
                .add(dense(builder.outSize));
    }

    /** Creates a 2-D convolution with a square kernel, square padding and the given filter count. */
    private static Block paddedConv(long kernel, long padding, int filters) {
        return Conv2d.builder()
                .setKernelShape(new Shape(kernel, kernel))
                .optPadding(new Shape(padding, padding))
                .setFilters(filters)
                .build();
    }

    /** Creates the 3x3, stride-2 max pooling block used after convolution stages 1, 2 and 5. */
    private static Block maxPool() {
        return Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2));
    }

    /** Creates a dense ({@link Linear}) layer with the given number of units. */
    private static Block dense(long units) {
        return Linear.builder().setUnits(units).build();
    }

    /** Creates a {@link Dropout} block configured with the given rate. */
    private static Block dropout(float rate) {
        return Dropout.builder().optRate(rate).build();
    }

    /**
     * Creates a builder pre-populated with the default configuration.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Collects the configurable parameters of the network and builds it. */
    public static final class Builder {
        /** Rate handed to both {@link Dropout} blocks; defaults to 0.5. */
        float dropOutRate = 0.5f;
        /** Required length of {@link #numChannels}: five convolutions plus two dense layers. */
        int numLayers = 7;
        /** Filter counts for the five convolutions followed by unit counts for the two dense layers. */
        int[] numChannels = new int[]{96, 256, 384, 384, 256, 4096, 4096};
        /** Unit count of the final dense layer; defaults to 10. */
        long outSize = DEFAULT_OUT_SIZE;

        Builder() {
        }

        /**
         * Sets the rate used by both dropout blocks.
         *
         * @param dropOutRate the dropout rate
         * @return this builder
         */
        public Builder setDropOutRate(float dropOutRate) {
            this.dropOutRate = dropOutRate;
            return this;
        }

        /**
         * Sets the filter counts of the five convolutions (indices 0-4) and the unit counts of the
         * two dense layers (indices 5-6).
         *
         * @param numChannels the per-layer sizes; must contain exactly {@code numLayers} entries
         * @return this builder
         * @throws IllegalArgumentException if the array length differs from {@code numLayers}
         * @throws NullPointerException if {@code numChannels} is {@code null}
         */
        public Builder setNumChannels(int[] numChannels) {
            if (numChannels.length != this.numLayers) {
                throw new IllegalArgumentException("number of channels should be equal to " + this.numLayers);
            }
            // Defensive copy: later changes to the caller's array must not alter this builder.
            this.numChannels = numChannels.clone();
            return this;
        }

        /**
         * Sets the number of units of the final dense layer.
         *
         * @param outSize the output size; must be positive
         * @return this builder
         * @throws IllegalArgumentException if {@code outSize} is not positive
         */
        public Builder optOutSize(long outSize) {
            if (outSize <= 0) {
                throw new IllegalArgumentException("outSize must be positive, but was " + outSize);
            }
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the network from the current configuration.
         *
         * @return the assembled {@link Block}
         */
        public Block build() {
            return AlexNet.alexNet(this);
        }
    }
}

