/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.examples.gpu;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.runtime.JCuda;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.externalresource.ExternalResourceInfo;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.serialization.SimpleStringEncoder;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.connector.source.Source;
import org.apache.flink.connector.datagen.source.DataGeneratorSource;
import org.apache.flink.connector.datagen.source.GeneratorFunction;
import org.apache.flink.connector.file.sink.FileSink;
import org.apache.flink.core.fs.Path;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.ParameterTool;
import org.apache.flink.util.Preconditions;

public class MatrixVectorMul {
    private static final int DEFAULT_DIM = 10;
    private static final int DEFAULT_DATA_SIZE = 100;
    private static final String DEFAULT_RESOURCE_NAME = "gpu";

    public static void main(String[] args) throws Exception {
        ParameterTool params = ParameterTool.fromArgs((String[])args);
        System.out.println("Usage: MatrixVectorMul [--output <path>] [--dimension <dimension> --data-size <data_size>] [--resource-name <resource_name>]");
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.getConfig().setGlobalJobParameters((ExecutionConfig.GlobalJobParameters)params);
        int dimension = params.getInt("dimension", 10);
        int dataSize = params.getInt("data-size", 100);
        String resourceName = params.get("resource-name", DEFAULT_RESOURCE_NAME);
        GeneratorFunction & Serializable generatorFunction = (GeneratorFunction & Serializable)index -> {
            ArrayList<Float> randomRecord = new ArrayList<Float>();
            for (int i = 0; i < dimension; ++i) {
                randomRecord.add(Float.valueOf((float)Math.random()));
            }
            return randomRecord;
        };
        DataGeneratorSource generatorSource = new DataGeneratorSource((GeneratorFunction)generatorFunction, (long)dataSize, Types.LIST((TypeInformation)Types.FLOAT));
        SingleOutputStreamOperator result = env.fromSource((Source)generatorSource, WatermarkStrategy.noWatermarks(), "Vectors Source").map((MapFunction)new Multiplier(dimension, resourceName));
        if (params.has("output")) {
            result.sinkTo((Sink)FileSink.forRowFormat((Path)new Path(params.get("output")), (Encoder)new SimpleStringEncoder()).build());
        } else {
            System.out.println("Printing result to stdout. Use --output to specify output path.");
            result.print();
        }
        env.execute("Matrix-Vector Multiplication");
    }

    private static final class Multiplier
    extends RichMapFunction<List<Float>, List<Float>> {
        private final int dimension;
        private final String resourceName;
        private Pointer matrixPointer;

        Multiplier(int dimension, String resourceName) {
            this.dimension = dimension;
            this.resourceName = resourceName;
        }

        public void open(OpenContext openContext) {
            String originTempDir = System.getProperty("java.io.tmpdir");
            String newTempDir = originTempDir + "/jcuda-" + String.valueOf(UUID.randomUUID());
            System.setProperty("java.io.tmpdir", newTempDir);
            Set externalResourceInfos = this.getRuntimeContext().getExternalResourceInfos(this.resourceName);
            Preconditions.checkState((!externalResourceInfos.isEmpty() ? 1 : 0) != 0, (Object)"The MatrixVectorMul needs at least one GPU device while finding 0 GPU.");
            Optional firstIndexOptional = ((ExternalResourceInfo)externalResourceInfos.iterator().next()).getProperty("index");
            Preconditions.checkState((boolean)firstIndexOptional.isPresent());
            this.matrixPointer = new Pointer();
            float[] matrix = new float[this.dimension * this.dimension];
            for (int i = 0; i < this.dimension * this.dimension; ++i) {
                matrix[i] = (float)Math.random();
            }
            JCuda.cudaSetDevice((int)Integer.parseInt((String)firstIndexOptional.get()));
            JCublas.cublasInit();
            JCublas.cublasAlloc((int)(this.dimension * this.dimension), (int)4, (Pointer)this.matrixPointer);
            JCublas.cublasSetVector((int)(this.dimension * this.dimension), (int)4, (Pointer)Pointer.to((float[])matrix), (int)1, (Pointer)this.matrixPointer, (int)1);
            System.setProperty("java.io.tmpdir", originTempDir);
        }

        public List<Float> map(List<Float> value) {
            float[] input = new float[this.dimension];
            float[] output = new float[this.dimension];
            Pointer inputPointer = new Pointer();
            Pointer outputPointer = new Pointer();
            for (int i = 0; i < this.dimension; ++i) {
                input[i] = value.get(i).floatValue();
                output[i] = 0.0f;
            }
            JCublas.cublasAlloc((int)this.dimension, (int)4, (Pointer)inputPointer);
            JCublas.cublasAlloc((int)this.dimension, (int)4, (Pointer)outputPointer);
            JCublas.cublasSetVector((int)this.dimension, (int)4, (Pointer)Pointer.to((float[])input), (int)1, (Pointer)inputPointer, (int)1);
            JCublas.cublasSetVector((int)this.dimension, (int)4, (Pointer)Pointer.to((float[])output), (int)1, (Pointer)outputPointer, (int)1);
            JCublas.cublasSgemv((char)'n', (int)this.dimension, (int)this.dimension, (float)1.0f, (Pointer)this.matrixPointer, (int)this.dimension, (Pointer)inputPointer, (int)1, (float)0.0f, (Pointer)outputPointer, (int)1);
            JCublas.cublasGetVector((int)this.dimension, (int)4, (Pointer)outputPointer, (int)1, (Pointer)Pointer.to((float[])output), (int)1);
            JCublas.cublasFree((Pointer)inputPointer);
            JCublas.cublasFree((Pointer)outputPointer);
            ArrayList<Float> outputList = new ArrayList<Float>();
            for (int i = 0; i < this.dimension; ++i) {
                outputList.add(Float.valueOf(output[i]));
            }
            return outputList;
        }

        public void close() {
            JCublas.cublasFree((Pointer)this.matrixPointer);
            JCublas.cublasShutdown();
        }
    }
}

