/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.mxnet.engine.MxNDArray;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.MxnetLibrary;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import com.sun.jna.Pointer;
import java.util.Arrays;

/**
 * {@link ParameterServer} implementation backed by a native MXNet KVStore.
 *
 * <p>The store handle is obtained from {@code JnaUtils.parameterStoreCreate("device")} and an
 * {@link Optimizer}-driven updater callback is registered on it at construction time. The
 * handle is released by {@link #close()}.
 */
public class MxParameterServer extends NativeResource implements ParameterServer {

    /**
     * Strong reference to the updater callback registered with the native store.
     *
     * <p>JNA does not keep callback objects passed to native code strongly reachable; if the
     * only reference is the argument of the registration call, the callback may be garbage
     * collected while native code can still invoke it. Holding it in a field ties its lifetime
     * to this server instance.
     */
    private final OptimizerCallback callback;

    /**
     * Creates the native store and registers an updater that delegates to {@code optimizer}.
     *
     * @param optimizer the optimizer invoked by the updater callback for each parameter
     */
    public MxParameterServer(Optimizer optimizer) {
        super(MxParameterServer.createdKVStore());
        this.callback = new OptimizerCallback(optimizer);
        JnaUtils.parameterStoreSetUpdater(this.getHandle(), null, this.callback, null);
    }

    /**
     * Initializes the store entry for {@code parameterId} with the given arrays.
     *
     * @param parameterId the key under which every element of {@code values} is registered
     * @param values the arrays to register; one key entry is passed per array
     */
    public void init(String parameterId, NDArray[] values) {
        String[] keys = repeatKey(parameterId, values.length);
        NDList vals = new NDList(values);
        JnaUtils.parameterStoreInit(this.getHandle(), values.length, keys, vals);
    }

    /**
     * Pushes the given arrays to the store under {@code parameterId}.
     *
     * @param parameterId the key associated with every element of {@code grads}
     * @param grads the arrays to push; one key entry is passed per array
     * @param priority forwarded unchanged to {@code JnaUtils.parameterStorePush}
     */
    public void push(String parameterId, NDArray[] grads, int priority) {
        String[] keys = repeatKey(parameterId, grads.length);
        NDList vals = new NDList(grads);
        JnaUtils.parameterStorePush(this.getHandle(), grads.length, keys, vals, priority);
    }

    /**
     * Pulls the store entry for {@code parameterId} into the given arrays.
     *
     * @param parameterId the key associated with every element of {@code weights}
     * @param weights the destination arrays; one key entry is passed per array
     * @param priority forwarded unchanged to {@code JnaUtils.parameterStorePull}
     */
    public void pull(String parameterId, NDArray[] weights, int priority) {
        String[] keys = repeatKey(parameterId, weights.length);
        NDList vals = new NDList(weights);
        JnaUtils.parameterStorePull(this.getHandle(), weights.length, keys, vals, priority);
    }

    /** Builds a key array of length {@code count} in which every slot is {@code parameterId}. */
    private static String[] repeatKey(String parameterId, int count) {
        String[] keys = new String[count];
        Arrays.fill(keys, parameterId);
        return keys;
    }

    /** Creates the native store by calling {@code JnaUtils.parameterStoreCreate("device")}. */
    private static Pointer createdKVStore() {
        return JnaUtils.parameterStoreCreate("device");
    }

    /**
     * Releases the native store.
     *
     * <p>The handle is swapped to {@code null} atomically, so the native close call runs at
     * most once even if this method is invoked repeatedly.
     */
    @Override
    public void close() {
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JnaUtils.parameterStoreClose(pointer);
        }
    }

    /** Updater callback that forwards each native update request to an {@link Optimizer}. */
    private static final class OptimizerCallback
            implements MxnetLibrary.MXKVStoreStrUpdater {
        private final Optimizer optimizer;

        OptimizerCallback(Optimizer optimizer) {
            this.optimizer = optimizer;
        }

        /**
         * Wraps the two native array handles and asks the optimizer to update the weight.
         *
         * @param parameterId the key of the parameter being updated
         * @param recv native handle wrapped as the gradient array
         * @param local native handle wrapped as the weight array
         * @param handle not used by this implementation
         */
        @Override
        public void apply(String parameterId, Pointer recv, Pointer local, Pointer handle) {
            // The sub-manager is closed when the try block exits.
            try (MxNDManager manager = MxNDManager.getSystemManager().newSubManager()) {
                MxNDArray grad = manager.create(recv);
                MxNDArray weight = manager.create(local);
                // NOTE(review): presumably keeps the wrappers from freeing native arrays that
                // were handed in by the caller when the sub-manager closes - confirm against
                // MxNDArray.setShouldFree.
                grad.setShouldFree(false);
                weight.setShouldFree(false);
                this.optimizer.update(parameterId, weight, grad);
            }
        }
    }
}

