/*
 * Decompiled with CFR 0.152.
 */
package org.redfx.strange.local;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import org.redfx.strange.Complex;
import org.redfx.strange.Gate;
import org.redfx.strange.Program;
import org.redfx.strange.QuantumExecutionEnvironment;
import org.redfx.strange.Qubit;
import org.redfx.strange.Result;
import org.redfx.strange.Step;
import org.redfx.strange.gate.Identity;
import org.redfx.strange.gate.PermutationGate;
import org.redfx.strange.gate.ProbabilitiesGate;
import org.redfx.strange.gate.Swap;
import org.redfx.strange.local.Computations;

/**
 * {@link QuantumExecutionEnvironment} that runs a {@link Program} locally by evolving
 * a full amplitude vector of size {@code 2^nQubits} through the program's decomposed steps.
 */
public class SimpleQuantumExecutionEnvironment
implements QuantumExecutionEnvironment {

    /**
     * Selects how a regular (non-probabilities, non-permutation) step is applied:
     * {@code true} delegates to {@link Computations#calculateNewState}, {@code false}
     * builds the full step matrix and multiplies it with the vector.
     */
    private static final boolean USE_VECTOR_STATE_UPDATE = true;

    /**
     * Prints {@code s} to {@code System.err} with a {@code [DBG] } prefix, but only when the
     * system property {@code dbg} is exactly {@code "true"}.
     *
     * @param s the message to print
     */
    public static void dbg(String s) {
        if (isDbgEnabled()) {
            System.err.println("[DBG] " + s);
        }
    }

    /** Returns whether the {@code dbg} system property is set to exactly {@code "true"}. */
    private static boolean isDbgEnabled() {
        return "true".equals(System.getProperty("dbg", "false"));
    }

    /**
     * Renders an amplitude vector for debug output. Returns an empty string when debugging is
     * disabled so the (potentially large) vector is not stringified needlessly.
     */
    private static String vectorForDbg(Complex[] vector) {
        return isDbgEnabled() ? Arrays.toString(vector) : "";
    }

    /**
     * Simulates {@code p} synchronously.
     *
     * <p>Builds the initial amplitude vector from {@link Program#getInitialAlphas()}, decomposes
     * the program steps (caching them on {@code p} via {@link Program#setDecomposedSteps}),
     * applies every non-empty decomposed step, records the vector as an intermediate result
     * whenever a step's {@link Step#getComplexStep()} is {@code > -1}, then measures the system
     * and stores the {@link Result} on {@code p}.
     *
     * @param p the program to run
     * @return the result, also stored on {@code p} via {@link Program#setResult}
     */
    @Override
    public Result runProgram(Program p) {
        dbg("runProgram ");
        int nQubits = p.getNumberQubits();
        Qubit[] qubit = new Qubit[nQubits];
        for (int i = 0; i < nQubits; ++i) {
            qubit[i] = new Qubit();
        }
        Complex[] probs = this.createInitialState(nQubits, p.getInitialAlphas());
        List<Step> steps = p.getSteps();
        List<Step> simpleSteps = this.getOrCreateDecomposedSteps(p, steps, nQubits);
        // NOTE(review): sized by the original step count, while intermediates are indexed by
        // Step.getComplexStep() of the decomposed steps — presumably those indices refer back to
        // the original steps; confirm against Computations.decomposeStep.
        Result result = new Result(nQubits, steps.size());
        int cnt = 0;
        result.setIntermediateProbability(0, probs);
        dbg("START RUN, number of steps = " + simpleSteps.size());
        for (Step step : simpleSteps) {
            if (step.getGates().isEmpty()) {
                continue;
            }
            dbg("RUN STEP " + step + ", cnt = " + cnt);
            ++cnt;
            dbg("before this step, probs = " + vectorForDbg(probs));
            probs = this.applyStep(step, probs, qubit);
            dbg("after this step, probs = " + vectorForDbg(probs));
            int idx = step.getComplexStep();
            if (idx > -1) {
                result.setIntermediateProbability(idx, probs);
            }
        }
        dbg("DONE RUN, probability vector = " + vectorForDbg(probs));
        this.printProbs(probs);
        double[] qp = this.calculateQubitStatesFromVector(probs);
        for (int i = 0; i < nQubits; ++i) {
            qubit[i].setProbability(qp[i]);
        }
        result.measureSystem();
        p.setResult(result);
        return result;
    }

    /**
     * Runs {@code p} on a newly started thread and hands the {@link Result} to {@code result}.
     * If {@link #runProgram(Program)} throws on that thread, the consumer is not invoked.
     *
     * @param p the program to run
     * @param result callback receiving the result
     */
    @Override
    public void runProgram(Program p, Consumer<Result> result) {
        Thread t = new Thread(() -> result.accept(this.runProgram(p)));
        t.start();
    }

    /**
     * Builds the initial amplitude vector of length {@code 2^nQubits}. The amplitude of basis
     * index {@code i} is the product, over all qubits {@code j}, of {@code initalpha[j]} when the
     * bit associated with qubit {@code j} is 0, or {@code sqrt(1 - initalpha[j]^2)} when it is 1.
     *
     * <p>NOTE(review): qubit {@code j} is associated here with bit {@code nQubits - j - 1}
     * (qubit 0 = most significant bit), whereas {@link #calculateQubitStatesFromVector} reads
     * qubit {@code i} from bit {@code i} (qubit 0 = least significant bit). The two conventions
     * only agree when all initial alphas are equal — confirm against {@code Program} before
     * relying on non-uniform initialization.
     */
    private Complex[] createInitialState(int nQubits, double[] initalpha) {
        int dim = 1 << nQubits;
        Complex[] probs = new Complex[dim];
        for (int i = 0; i < dim; ++i) {
            Complex amplitude = Complex.ONE;
            for (int j = 0; j < nQubits; ++j) {
                int bit = (i >> (nQubits - j - 1)) & 1;
                amplitude = bit == 0
                        ? amplitude.mul(initalpha[j])
                        : amplitude.mul(Math.sqrt(1.0 - initalpha[j] * initalpha[j]));
            }
            probs[i] = amplitude;
        }
        return probs;
    }

    /**
     * Returns the decomposed steps cached on {@code p}; when none are cached, decomposes every
     * step in {@code steps} with {@link Computations#decomposeStep} and caches the result on
     * {@code p}.
     */
    private List<Step> getOrCreateDecomposedSteps(Program p, List<Step> steps, int nQubits) {
        List<Step> simpleSteps = p.getDecomposedSteps();
        if (simpleSteps == null) {
            simpleSteps = new ArrayList<>();
            for (Step step : steps) {
                simpleSteps.addAll(Computations.decomposeStep(step, nQubits));
            }
            p.setDecomposedSteps(simpleSteps);
        }
        return simpleSteps;
    }

    /** Prints the amplitude vector via {@link Complex#printArray}. */
    private void printProbs(Complex[] p) {
        Complex.printArray(p);
    }

    /**
     * Applies one decomposed step to {@code vector}.
     *
     * <ul>
     *   <li>If the first gate is a {@link ProbabilitiesGate}, it receives the current vector and
     *       the vector is returned unchanged.</li>
     *   <li>If the step is a single {@link PermutationGate}, the vector is permuted via
     *       {@link Computations#permutateVector}.</li>
     *   <li>Otherwise the new state is computed (see {@link #USE_VECTOR_STATE_UPDATE}).</li>
     * </ul>
     *
     * @param step the step to apply
     * @param vector the current amplitude vector
     * @param qubits the program's qubits; only their count is used
     * @return the amplitude vector after the step
     */
    private Complex[] applyStep(Step step, Complex[] vector, Qubit[] qubits) {
        dbg("start applystep, vectorsize = " + vector.length + ", ql = " + qubits.length);
        long s0 = System.currentTimeMillis();
        List<Gate> gates = step.getGates();
        if (!gates.isEmpty() && gates.get(0) instanceof ProbabilitiesGate) {
            ProbabilitiesGate probGate = (ProbabilitiesGate) gates.get(0);
            probGate.setProbabilites(vector);
            return vector;
        }
        if (gates.size() == 1 && gates.get(0) instanceof PermutationGate) {
            PermutationGate pg = (PermutationGate) gates.get(0);
            return Computations.permutateVector(vector, pg.getIndex1(), pg.getIndex2());
        }
        Complex[] result = USE_VECTOR_STATE_UPDATE
                ? Computations.calculateNewState(gates, vector, qubits.length)
                : this.applyStepMatrix(gates, vector, qubits.length);
        long s1 = System.currentTimeMillis();
        dbg("done applystep took " + (s1 - s0));
        return result;
    }

    /**
     * Applies {@code gates} by building the full step matrix and multiplying it with
     * {@code vector}.
     *
     * @throws IllegalStateException if the step matrix dimension differs from the vector length
     */
    private Complex[] applyStepMatrix(List<Gate> gates, Complex[] vector, int nQubits) {
        dbg("start calcstepmatrix with gates " + gates);
        Complex[][] a = this.calculateStepMatrix(gates, nQubits);
        dbg("done calcstepmatrix");
        dbg("vector");
        if (a.length != vector.length) {
            System.err.println("fatal issue calculating step for gates " + gates);
            throw new IllegalStateException("Wrong length of matrix or probability vector: expected " + vector.length + " but got " + a.length);
        }
        dbg("start matrix-vector multiplication for vector size = " + vector.length);
        Complex[] result = new Complex[vector.length];
        for (int i = 0; i < vector.length; ++i) {
            Complex sum = Complex.ZERO;
            for (int j = 0; j < vector.length; ++j) {
                sum = sum.add(a[i][j].mul(vector[j]));
            }
            result[i] = sum;
        }
        return result;
    }

    /** Delegates to {@link Computations#calculateStepMatrix}, passing this environment along. */
    private Complex[][] calculateStepMatrix(List<Gate> gates, int nQubits) {
        return Computations.calculateStepMatrix(gates, nQubits, this);
    }

    /**
     * Computes the Kronecker (tensor) product {@code a ⊗ b}. Both inputs are assumed square:
     * only {@code a.length} and {@code b.length} are used as dimensions.
     *
     * @param a left operand ({@code d1 x d1})
     * @param b right operand ({@code d2 x d2})
     * @return the {@code (d1*d2) x (d1*d2)} product
     * @deprecated retained for existing callers; {@link #createPermutationMatrix} still uses it
     */
    @Deprecated
    public Complex[][] tensor(Complex[][] a, Complex[][] b) {
        int d1 = a.length;
        int d2 = b.length;
        Complex[][] result = new Complex[d1 * d2][d1 * d2];
        for (int rowa = 0; rowa < d1; ++rowa) {
            for (int cola = 0; cola < d1; ++cola) {
                for (int rowb = 0; rowb < d2; ++rowb) {
                    for (int colb = 0; colb < d2; ++colb) {
                        result[d2 * rowa + rowb][d2 * cola + colb] = a[rowa][cola].mul(b[rowb][colb]);
                    }
                }
            }
        }
        return result;
    }

    /**
     * For each qubit {@code i}, sums {@code |amplitude|^2} over all basis indices whose bit
     * {@code i} is set. The qubit count is derived as {@code round(log2(vector length))}.
     *
     * @param vectorresult the final amplitude vector
     * @return per-qubit probability of that bit being 1
     */
    private double[] calculateQubitStatesFromVector(Complex[] vectorresult) {
        int nq = (int) Math.round(Math.log(vectorresult.length) / Math.log(2.0));
        double[] answer = new double[nq];
        int ressize = 1 << nq;
        for (int i = 0; i < nq; ++i) {
            for (int j = 0; j < ressize; ++j) {
                if (((j >> i) & 1) == 1) {
                    answer[i] += vectorresult[j].abssqr();
                }
            }
        }
        return answer;
    }

    /**
     * Builds the matrix that applies the {@link Swap} gate at position {@code first} and
     * {@link Identity} on every other position, combined with {@link #tensor}.
     *
     * <p>The swap consumes positions {@code first} and {@code first + 1}; {@code second} is
     * not read and is presumably expected to equal {@code first + 1} — confirm with callers.
     *
     * @param first lower position of the swapped pair
     * @param second ignored
     * @param n number of positions
     * @return the permutation matrix
     */
    public Complex[][] createPermutationMatrix(int first, int second, int n) {
        Complex[][] swapMatrix = new Swap().getMatrix();
        Complex[][] iMatrix = new Identity().getMatrix();
        Complex[][] answer = iMatrix;
        int i = 1;
        if (first == 0) {
            answer = swapMatrix;
            ++i;
        }
        while (i < n) {
            if (i == first) {
                // The swap covers two positions, so skip the next one as well.
                ++i;
                answer = this.tensor(answer, swapMatrix);
            } else {
                answer = this.tensor(answer, iMatrix);
            }
            ++i;
        }
        return answer;
    }
}

