/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.optimizer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.optimizer.Adadelta;
import ai.djl.training.optimizer.Adagrad;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Nag;
import ai.djl.training.optimizer.RmsProp;
import ai.djl.training.optimizer.Sgd;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * Base class for optimizers that update parameters from their gradients.
 *
 * <p>Instances are configured through builders extending {@link OptimizerBuilder}; the static
 * factory methods ({@link #sgd()}, {@link #nag()}, {@link #adam()}, {@link #rmsprop()},
 * {@link #adagrad()}, {@link #adadelta()}) return those builders.
 */
public abstract class Optimizer {

    /** Gradient rescale factor taken from the builder (builder default {@code 1.0f}). */
    protected float rescaleGrad;

    /**
     * Gradient clipping value taken from the builder (builder default {@code -1.0f}; presumably a
     * negative value means "no clipping" — confirm in subclasses).
     */
    protected float clipGrad;

    private final float weightDecays;
    private final int beginNumUpdate;

    // Largest update count seen across all parameters. The per-parameter counters live in a
    // ConcurrentHashMap, so concurrent callers are expected; an atomic max keeps this consistent.
    private final AtomicInteger numUpdate = new AtomicInteger();

    private final Map<String, Integer> updateCounts = new ConcurrentHashMap<>();

    /**
     * Creates an optimizer from the settings collected in {@code builder}.
     *
     * @param builder the builder holding rescale, weight-decay, clipping and start-count settings
     */
    public Optimizer(OptimizerBuilder<?> builder) {
        this.rescaleGrad = builder.rescaleGrad;
        this.weightDecays = builder.weightDecays;
        this.clipGrad = builder.clipGrad;
        this.beginNumUpdate = builder.beginNumUpdate;
    }

    /**
     * Returns a new builder for an {@link Sgd} optimizer.
     *
     * @return a fresh {@link Sgd.Builder}
     */
    public static Sgd.Builder sgd() {
        return new Sgd.Builder();
    }

    /**
     * Returns a new builder for a {@link Nag} optimizer.
     *
     * @return a fresh {@link Nag.Builder}
     */
    public static Nag.Builder nag() {
        return new Nag.Builder();
    }

    /**
     * Returns a new builder for an {@link Adam} optimizer.
     *
     * @return a fresh {@link Adam.Builder}
     */
    public static Adam.Builder adam() {
        return new Adam.Builder();
    }

    /**
     * Returns a new builder for an {@link RmsProp} optimizer.
     *
     * @return a fresh {@link RmsProp.Builder}
     */
    public static RmsProp.Builder rmsprop() {
        return new RmsProp.Builder();
    }

    /**
     * Returns a new builder for an {@link Adagrad} optimizer.
     *
     * @return a fresh {@link Adagrad.Builder}
     */
    public static Adagrad.Builder adagrad() {
        return new Adagrad.Builder();
    }

    /**
     * Returns a new builder for an {@link Adadelta} optimizer.
     *
     * @return a fresh {@link Adadelta.Builder}
     */
    public static Adadelta.Builder adadelta() {
        return new Adadelta.Builder();
    }

    /**
     * Returns the weight decay configured via {@link OptimizerBuilder#optWeightDecays(float)}.
     *
     * @return the weight decay ({@code 0.0f} if never set on the builder)
     */
    protected float getWeightDecay() {
        return weightDecays;
    }

    /**
     * Increments the update counter of {@code parameterId} and returns the global update count.
     *
     * <p>The first call for a parameter records {@code beginNumUpdate + 1}; later calls add one.
     * The returned value is the maximum count observed across <em>all</em> parameters so far, not
     * necessarily this parameter's own count.
     *
     * @param parameterId identifier of the parameter being updated
     * @return the largest update count recorded for any parameter, including this call
     */
    protected int updateCount(String parameterId) {
        int count =
                updateCounts.compute(
                        parameterId, (key, val) -> val == null ? beginNumUpdate + 1 : val + 1);
        // Atomic max: a plain read-max-write could lose a larger value under concurrent calls.
        return numUpdate.accumulateAndGet(count, Math::max);
    }

    /**
     * Updates a parameter using its gradient.
     *
     * @param parameterId identifier of the parameter
     * @param weight the parameter array (presumably updated in place — confirm in subclasses)
     * @param grad the gradient for {@code weight}
     */
    public abstract void update(String parameterId, NDArray weight, NDArray grad);

    /**
     * Looks up per-device optimizer state for {@code key}, creating it on first use.
     *
     * <p>If {@code state} has no entry for {@code key}, a new per-device map is created holding
     * {@code defaultFunction.apply(key)} (after {@code detach()}) under {@code device}. If the key
     * exists but has no array for {@code device}, one existing array is copied there with {@code
     * toDevice(device, true)}.
     *
     * @param state the optimizer state, keyed by state name and then by device
     * @param key the state entry to look up
     * @param device the device the returned array must live on
     * @param defaultFunction produces the initial array for {@code key} when none exists
     * @return the state array for {@code key} on {@code device}
     */
    protected NDArray withDefaultState(
            Map<String, Map<Device, NDArray>> state,
            String key,
            Device device,
            Function<String, NDArray> defaultFunction) {
        Map<Device, NDArray> arrayMap =
                state.computeIfAbsent(
                        key,
                        k -> {
                            Map<Device, NDArray> map = new ConcurrentHashMap<>();
                            NDArray s = defaultFunction.apply(k);
                            s.detach();
                            map.put(device, s);
                            return map;
                        });
        // Any existing entry serves as the source for the copy to the requested device.
        return arrayMap.computeIfAbsent(
                device, k -> arrayMap.values().iterator().next().toDevice(device, true));
    }

    /**
     * Base builder holding the settings shared by all optimizers.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class OptimizerBuilder<T extends OptimizerBuilder> {
        private float rescaleGrad = 1.0f;
        private float weightDecays;
        private float clipGrad = -1.0f;
        private int beginNumUpdate;

        /** Creates a builder with default settings. */
        protected OptimizerBuilder() {
        }

        /**
         * Sets the gradient rescale factor (default {@code 1.0f}).
         *
         * @param rescaleGrad the rescale factor
         * @return this builder
         */
        public T setRescaleGrad(float rescaleGrad) {
            this.rescaleGrad = rescaleGrad;
            return self();
        }

        /**
         * Sets the weight decay (default {@code 0.0f}).
         *
         * @param weightDecays the weight decay
         * @return this builder
         */
        public T optWeightDecays(float weightDecays) {
            this.weightDecays = weightDecays;
            return self();
        }

        /**
         * Sets the gradient clipping value (default {@code -1.0f}).
         *
         * @param clipGrad the clipping value
         * @return this builder
         */
        public T optClipGrad(float clipGrad) {
            this.clipGrad = clipGrad;
            return self();
        }

        /**
         * Sets the count each parameter's update counter starts from (default {@code 0}).
         *
         * @param beginNumUpdate the starting update count
         * @return this builder
         */
        public T optBeginNumUpdate(int beginNumUpdate) {
            this.beginNumUpdate = beginNumUpdate;
            return self();
        }

        /**
         * Returns this builder as its concrete type, for fluent chaining.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

