/*
 * Decompiled with CFR 0.152.
 */
package org.apfloat.aparapi;

import com.aparapi.Kernel;
import org.apfloat.ApfloatRuntimeException;
import org.apfloat.spi.ArrayAccess;

/**
 * Aparapi {@link Kernel} that performs modular-arithmetic operations on a
 * {@code long[]} data array.
 *
 * <p>The operation executed by {@link #run()} is selected with
 * {@link #setOp(int)} using one of the op constants declared below; the
 * remaining setters supply the inputs that the selected operation reads.
 * Instances are not created directly: {@link #getInstance()} returns one
 * instance per calling thread.
 *
 * <p>NOTE(review): everything reachable from {@link #run()} is presumably
 * translated by Aparapi for execution on a compute device, so those methods
 * are deliberately kept to primitives, arrays and simple control flow --
 * confirm against the Aparapi restrictions before restructuring them.
 */
class LongKernel
extends Kernel {
    /** Op code: {@link #run()} executes {@code columnTableFNT()}. */
    public static final int TRANSFORM_ROWS = 1;
    /** Op code: {@link #run()} executes {@code inverseColumnTableFNT()}. */
    public static final int INVERSE_TRANSFORM_ROWS = 2;
    /** Op code: {@link #run()} executes {@code transpose()}. */
    public static final int TRANSPOSE = 3;
    /** Op code: {@link #run()} executes {@code permute()}. */
    public static final int PERMUTE = 4;
    /** Op code: {@link #run()} executes {@code multiplyElements()}. */
    public static final int MULTIPLY_ELEMENTS = 5;
    /** Op code: {@link #run()} executes {@code transformColumns()}, applying the w / ww powers after the butterfly. */
    public static final int TRANSFORM_COLUMNS = 6;
    /** Op code: {@link #run()} executes {@code transformColumns()}, applying the w / ww powers before the butterfly. */
    public static final int INVERSE_TRANSFORM_COLUMNS = 7;

    /** Holds one kernel instance per thread; see {@link #getInstance()}. */
    private static final ThreadLocal<LongKernel> kernel = new ThreadLocal<LongKernel>(){

        @Override
        public LongKernel initialValue() {
            return new LongKernel();
        }
    };

    /** Distance in {@link #data} between consecutive elements of one transform; computed in {@link #setArrayAccess}. */
    private int stride;
    /** Transform length ({@code nn} in the transform methods). */
    private int length;
    /** The array operated on by every op. */
    private long[] data;
    /** Index in {@link #data} of the first element to process. */
    private int offset;
    /** Multipliers read as {@code wTable[t]} by the row transforms. */
    private long[] wTable = new long[]{0L};
    /** Pairs of element indexes swapped by {@code columnScramble}; never null (see {@link #setPermutationTable}). */
    private int[] permutationTable = new int[]{0};
    /** Number of valid entries in {@link #permutationTable}; 0 disables scrambling. */
    private int permutationTableLength;
    /** Modulus of all modular operations. */
    private long modulus;
    /** {@code 1.0 / modulus}, cached by {@link #setModulus} for {@code modMultiply}. */
    private double inverseModulus;
    /** Row length used for addressing by {@code transpose} and {@code permute}. */
    private int n2;
    /** Zero-terminated runs of row numbers consumed by {@code permute}. */
    private int[] index = new int[]{0};
    /** Number of entries of {@link #index} that {@code permute} scans. */
    private int indexCount;
    /** Row number of the first processed row, used by {@code multiplyElements}. */
    private int startRow;
    /** Column number corresponding to global id 0, used by {@code multiplyElements} and {@code transformColumns}. */
    private int startColumn;
    /** Number of rows processed by {@code multiplyElements}. */
    private int rows;
    /** Distance in {@link #data} between vertically adjacent elements in {@code multiplyElements} and {@code transformColumns}. */
    private int columns;
    /** Base that is raised to a power in {@code multiplyElements} and {@code transformColumns}. */
    private long w;
    /** Constant factor applied to every element by {@code multiplyElements}. */
    private long scaleFactor;
    /** Selected operation; one of the op constants. */
    private int op;
    /** Base that is raised to a power for the third row in {@code transformColumns}. */
    private long ww;
    /** First constant multiplier of the 3-point butterfly in {@code transformColumns}. */
    private long w1;
    /** Second constant multiplier of the 3-point butterfly in {@code transformColumns}. */
    private long w2;

    /** Instances are only created through the thread-local holder. */
    private LongKernel() {
    }

    /**
     * Returns the kernel instance that belongs to the calling thread,
     * creating it on first use.
     *
     * @return the calling thread's kernel
     */
    public static LongKernel getInstance() {
        return kernel.get();
    }

    /**
     * Sets the transform length. Call this before {@link #setArrayAccess},
     * which uses the length to derive the stride.
     *
     * @param length the transform length
     */
    public void setLength(int length) {
        this.length = length;
    }

    /**
     * Sets the data array and start offset from the given array access. When
     * the current length is non-zero the stride is also recomputed as
     * {@code arrayAccess.getLength() / length}; otherwise the previous stride
     * is kept.
     *
     * @param arrayAccess source of the {@code long[]} data, offset and total length
     * @throws ApfloatRuntimeException declared for errors raised by the array access
     */
    public void setArrayAccess(ArrayAccess arrayAccess) throws ApfloatRuntimeException {
        this.data = arrayAccess.getLongData();
        this.offset = arrayAccess.getOffset();
        if (this.length != 0) {
            this.stride = arrayAccess.getLength() / this.length;
        }
    }

    /**
     * Sets the multiplier table read by the row transforms.
     *
     * @param wTable the multiplier table; stored as is, without a null check
     */
    public void setWTable(long[] wTable) {
        this.wTable = wTable;
    }

    /**
     * Sets the permutation table applied by the row transforms.
     *
     * @param permutationTable pairs of element indexes to swap, or
     *        {@code null} to disable scrambling
     */
    public void setPermutationTable(int[] permutationTable) {
        // A null table is replaced by a one-element placeholder and a length of 0
        // (presumably because Aparapi cannot pass a null array -- TODO confirm).
        this.permutationTable = permutationTable == null ? new int[1] : permutationTable;
        this.permutationTableLength = permutationTable == null ? 0 : permutationTable.length;
    }

    /**
     * Forward transform of one strided sequence of {@link #data}, in place.
     *
     * <p>The work item with global id {@code g} processes the {@code length}
     * elements at {@code offset + g + k * stride}. The butterfly span
     * {@code mmax} starts at {@code length / 2} and is halved on every pass;
     * afterwards the elements are scrambled if a permutation table is set.
     * Nothing happens when the length is below 2.
     *
     * <p>NOTE(review): the halving loop looks like it assumes a power-of-two
     * length -- confirm against the caller.
     */
    private void columnTableFNT() {
        int istep = 0;
        int mmax = 0;
        int r = 0;
        long[] data = this.data;
        int offset = this.offset + this.getGlobalId();
        int stride = this.stride;
        int nn = this.length;
        if (nn >= 2) {
            r = 1;
            for (mmax = nn >> 1; mmax > 0; mmax >>= 1) {
                istep = mmax << 1;
                // Butterflies for m == 0: no table multiplication is applied here
                // (presumably because the factor would be 1).
                for (int i = offset; i < offset + nn * stride; i += istep * stride) {
                    int j = i + mmax * stride;
                    long a = data[i];
                    long b = data[j];
                    data[i] = this.modAdd(a, b);
                    data[j] = this.modSubtract(a, b);
                }
                // t tracks m * r, the table index of the factor for position m.
                int t = r;
                for (int m = 1; m < mmax; ++m) {
                    for (int i = offset + m * stride; i < offset + nn * stride; i += istep * stride) {
                        int j = i + mmax * stride;
                        long a = data[i];
                        long b = data[j];
                        data[i] = this.modAdd(a, b);
                        data[j] = this.modMultiply(this.wTable[t], this.modSubtract(a, b));
                    }
                    t += r;
                }
                r <<= 1;
            }
            if (this.permutationTableLength > 0) {
                this.columnScramble(offset);
            }
        }
    }

    /**
     * Inverse transform of one strided sequence of {@link #data}, in place.
     *
     * <p>Addresses the same elements as {@code columnTableFNT()}, but
     * scrambles first (if a permutation table is set) and then runs the
     * passes with the butterfly span {@code mmax} doubling from 1 while it is
     * below the length. The table factor is applied to the second operand
     * before the add / subtract. Nothing happens when the length is below 2.
     */
    private void inverseColumnTableFNT() {
        int istep = 0;
        int mmax = 0;
        int r = 0;
        long[] data = this.data;
        int offset = this.offset + this.getGlobalId();
        int stride = this.stride;
        int nn = this.length;
        if (nn >= 2) {
            if (this.permutationTableLength > 0) {
                this.columnScramble(offset);
            }
            r = nn;
            mmax = 1;
            while (nn > mmax) {
                istep = mmax << 1;
                r >>= 1;
                // Butterflies for m == 0: no table multiplication is applied here.
                for (int i = offset; i < offset + nn * stride; i += istep * stride) {
                    int j = i + mmax * stride;
                    long wTemp = data[j];
                    data[j] = this.modSubtract(data[i], wTemp);
                    data[i] = this.modAdd(data[i], wTemp);
                }
                // t tracks m * r, the table index of the factor for position m.
                int t = r;
                for (int m = 1; m < mmax; ++m) {
                    for (int i = offset + m * stride; i < offset + nn * stride; i += istep * stride) {
                        int j = i + mmax * stride;
                        long wTemp = this.modMultiply(this.wTable[t], data[j]);
                        data[j] = this.modSubtract(data[i], wTemp);
                        data[i] = this.modAdd(data[i], wTemp);
                    }
                    t += r;
                }
                mmax = istep;
            }
        }
    }

    /**
     * Swaps the element pairs listed in the permutation table within one
     * strided sequence.
     *
     * <p>Entries are consumed two at a time, so the table length is expected
     * to be even (an odd length would read one entry past
     * {@code permutationTableLength}).
     *
     * @param offset index in {@link #data} of the first element of the sequence
     */
    private void columnScramble(int offset) {
        for (int k = 0; k < this.permutationTableLength; k += 2) {
            int i = offset + this.permutationTable[k] * this.stride;
            int j = offset + this.permutationTable[k + 1] * this.stride;
            long tmp = this.data[i];
            this.data[i] = this.data[j];
            this.data[j] = tmp;
        }
    }

    /**
     * Multiplies two values modulo {@link #modulus}.
     *
     * <p>The 64-bit wrapping product is reduced by subtracting
     * {@code modulus * q}, where the quotient {@code q} is estimated in
     * floating point from the cached inverse modulus. A second, smaller
     * estimate and two conditional adjustments correct the rounding error of
     * the first estimate.
     *
     * <p>NOTE(review): the result is intended to be in {@code [0, modulus)};
     * the operand and modulus range for which that holds is not visible here.
     *
     * @param a first factor
     * @param b second factor
     * @return the reduced product
     */
    private long modMultiply(long a, long b) {
        long r = a * b - this.modulus * (long)((double)a * (double)b * this.inverseModulus);
        // Second reduction step: remove whatever multiple of the modulus is still left.
        r -= this.modulus * (long)((int)((double)r * this.inverseModulus));
        // Final corrections by at most one modulus in either direction.
        r = r >= this.modulus ? r - this.modulus : r;
        r = r < 0L ? r + this.modulus : r;
        return r;
    }

    /**
     * Adds two values and subtracts the modulus once if the sum reaches it.
     * Presumably both operands are already below the modulus -- verify
     * against callers.
     *
     * @param a first operand
     * @param b second operand
     * @return {@code a + b}, reduced by one modulus if it was {@code >= modulus}
     */
    private long modAdd(long a, long b) {
        long r = a + b;
        return r >= this.modulus ? r - this.modulus : r;
    }

    /**
     * Subtracts two values and adds the modulus once if the difference is
     * negative.
     *
     * @param a minuend
     * @param b subtrahend
     * @return {@code a - b}, increased by one modulus if it was negative
     */
    private long modSubtract(long a, long b) {
        long r = a - b;
        return r < 0L ? r + this.modulus : r;
    }

    /**
     * Sets the modulus and caches its floating-point reciprocal.
     *
     * @param modulus the modulus for all modular operations
     */
    public void setModulus(long modulus) {
        this.inverseModulus = 1.0 / (double)modulus;
        this.modulus = modulus;
    }

    /**
     * Returns the current modulus.
     *
     * @return the modulus
     */
    public long getModulus() {
        return this.modulus;
    }

    /**
     * Sets the row length used by {@code transpose} and {@code permute}.
     *
     * @param n2 the row length
     */
    public void setN2(int n2) {
        this.n2 = n2;
    }

    /**
     * Sets the row cycles consumed by {@code permute}.
     *
     * @param index zero-terminated runs of row numbers; stored as is
     */
    public void setIndex(int[] index) {
        this.index = index;
    }

    /**
     * Sets how many entries of the index array {@code permute} scans.
     *
     * @param indexCount number of index entries to scan
     */
    public void setIndexCount(int indexCount) {
        this.indexCount = indexCount;
    }

    /**
     * Swaps the element at row {@code j}, column {@code i} with the element
     * at row {@code i}, column {@code j}, where {@code i} and {@code j} are
     * the global ids in dimensions 0 and 1 and rows are {@code n2} elements
     * apart. Only work items with {@code i < j} do anything, so each pair is
     * swapped exactly once.
     */
    private void transpose() {
        int j;
        int i = this.getGlobalId(0);
        if (i < (j = this.getGlobalId(1))) {
            int position1 = this.offset + j * this.n2 + i;
            int position2 = this.offset + i * this.n2 + j;
            long tmp = this.data[position1];
            this.data[position1] = this.data[position2];
            this.data[position2] = tmp;
        }
    }

    /**
     * Moves elements between rows within the column given by the global id.
     *
     * <p>The index array is read as a sequence of runs: the first entry of a
     * run is a row number, and each following non-zero entry names the row
     * whose element is copied into the previously named row. A 0 entry ends
     * the run, and the element saved from the first row is stored into the
     * last named row.
     */
    private void permute() {
        int j = this.getGlobalId();
        for (int i = 0; i < this.indexCount; ++i) {
            int o = this.index[i];
            long tmp = this.data[this.offset + this.n2 * o + j];
            // The inner loop advances the outer loop's index up to the terminating 0;
            // the outer ++i then steps past that terminator.
            ++i;
            while (this.index[i] != 0) {
                int m = this.index[i];
                this.data[this.offset + this.n2 * o + j] = this.data[this.offset + this.n2 * m + j];
                o = m;
                ++i;
            }
            this.data[this.offset + this.n2 * o + j] = tmp;
        }
    }

    /**
     * Sets the row number of the first row processed by {@code multiplyElements}.
     *
     * @param startRow the first row number
     */
    public void setStartRow(int startRow) {
        this.startRow = startRow;
    }

    /**
     * Sets the column number that corresponds to global id 0.
     *
     * @param startColumn the first column number
     */
    public void setStartColumn(int startColumn) {
        this.startColumn = startColumn;
    }

    /**
     * Sets the number of rows processed by {@code multiplyElements}.
     *
     * @param rows the row count
     */
    public void setRows(int rows) {
        this.rows = rows;
    }

    /**
     * Sets the distance in the data array between vertically adjacent elements.
     *
     * @param columns the row length in elements
     */
    public void setColumns(int columns) {
        this.columns = columns;
    }

    /**
     * Sets the base raised to a power by {@code multiplyElements} and
     * {@code transformColumns}.
     *
     * @param w the base
     */
    public void setW(long w) {
        this.w = w;
    }

    /**
     * Sets the constant factor applied by {@code multiplyElements}.
     *
     * @param scaleFactor the constant factor
     */
    public void setScaleFactor(long scaleFactor) {
        this.scaleFactor = scaleFactor;
    }

    /**
     * Multiplies one column of the data by powers of {@link #w}.
     *
     * <p>With {@code c = startColumn + globalId}, the element in the
     * {@code i}-th processed row is multiplied by
     * {@code scaleFactor * w^((startRow + i) * c)}: the factor starts at
     * {@code scaleFactor * (w^startRow)^c} and is multiplied by {@code w^c}
     * after every row.
     */
    private void multiplyElements() {
        long[] data = this.data;
        int position = this.offset + this.getGlobalId();
        long rowFactor = this.modPow(this.w, this.startRow);
        long columnFactor = this.modPow(this.w, (long)this.startColumn + (long)this.getGlobalId());
        long rowStartFactor = this.modMultiply(this.scaleFactor, this.modPow(rowFactor, (long)this.startColumn + (long)this.getGlobalId()));
        for (int i = 0; i < this.rows; ++i) {
            data[position] = this.modMultiply(data[position], rowStartFactor);
            position += this.columns;
            rowStartFactor = this.modMultiply(rowStartFactor, columnFactor);
        }
    }

    /**
     * Raises {@code a} to the power {@code n} modulo {@link #modulus} by
     * repeated squaring.
     *
     * <p>An exponent of 0 yields 1. A negative exponent is replaced by
     * {@code modulus - 1 + n} (presumably relying on a prime modulus, where
     * {@code a^(modulus - 1)} is 1 -- TODO confirm).
     *
     * <p>NOTE(review): if a negative exponent is adjusted to exactly 0
     * ({@code n == -(modulus - 1)}), the first loop below never terminates,
     * because the zero check has already been passed.
     *
     * @param a the base
     * @param n the exponent
     * @return {@code a^n} reduced with {@code modMultiply}
     */
    private long modPow(long a, long n) {
        if (n == 0L) {
            return 1L;
        }
        if (n < 0L) {
            n = this.getModulus() - 1L + n;
        }
        long exponent = n;
        // Square away the trailing zero bits so that the result can start as a itself.
        while ((exponent & 1L) == 0L) {
            a = this.modMultiply(a, a);
            exponent >>= 1;
        }
        long r = a;
        exponent >>= 1;
        // Square-and-multiply over the remaining bits.
        while (exponent > 0L) {
            a = this.modMultiply(a, a);
            if ((exponent & 1L) != 0L) {
                r = this.modMultiply(r, a);
            }
            exponent >>= 1;
        }
        return r;
    }

    /**
     * Selects the operation executed by {@link #run()}.
     *
     * @param op one of the op constants of this class
     */
    public void setOp(int op) {
        this.op = op;
    }

    /**
     * Sets the base raised to a power for the third row in {@code transformColumns}.
     *
     * @param ww the base
     */
    public void setWw(long ww) {
        this.ww = ww;
    }

    /**
     * Sets the first constant multiplier used by {@code transformColumns}.
     *
     * @param w1 the multiplier
     */
    public void setW1(long w1) {
        this.w1 = w1;
    }

    /**
     * Sets the second constant multiplier used by {@code transformColumns}.
     *
     * @param w2 the multiplier
     */
    public void setW2(long w2) {
        this.w2 = w2;
    }

    /**
     * Kernel entry point: executes the operation selected with
     * {@link #setOp(int)} for the current work item. An op value that matches
     * none of the constants does nothing.
     */
    public void run() {
        // An if/else chain on compile-time constants; the constants are inlined by the
        // compiler, so this is equivalent to comparing against the literal values.
        if (this.op == TRANSFORM_ROWS) {
            this.columnTableFNT();
        } else if (this.op == INVERSE_TRANSFORM_ROWS) {
            this.inverseColumnTableFNT();
        } else if (this.op == TRANSPOSE) {
            this.transpose();
        } else if (this.op == PERMUTE) {
            this.permute();
        } else if (this.op == MULTIPLY_ELEMENTS) {
            this.multiplyElements();
        } else if (this.op == TRANSFORM_COLUMNS || this.op == INVERSE_TRANSFORM_COLUMNS) {
            this.transformColumns();
        }
    }

    /**
     * Three-point transform of one column, in place.
     *
     * <p>The work item with global id {@code i} reads the three elements at
     * {@code offset + i}, {@code offset + columns + i} and
     * {@code offset + 2 * columns + i}, combines them with the constants
     * {@link #w1} and {@link #w2}, and writes them back. The second and third
     * elements are additionally multiplied by {@code w^(startColumn + i)} and
     * {@code ww^(startColumn + i)}: before the butterfly for
     * {@link #INVERSE_TRANSFORM_COLUMNS}, after it for
     * {@link #TRANSFORM_COLUMNS}.
     */
    private void transformColumns() {
        int i = this.getGlobalId();
        long tmp1 = this.modPow(this.w, (long)this.startColumn + (long)i);
        long tmp2 = this.modPow(this.ww, (long)this.startColumn + (long)i);
        long x0 = this.data[this.offset + i];
        long x1 = this.data[this.offset + this.columns + i];
        long x2 = this.data[this.offset + 2 * this.columns + i];
        if (this.op == INVERSE_TRANSFORM_COLUMNS) {
            x1 = this.modMultiply(x1, tmp1);
            x2 = this.modMultiply(x2, tmp2);
        }
        long t = this.modAdd(x1, x2);
        x2 = this.modSubtract(x1, x2);
        x0 = this.modAdd(x0, t);
        t = this.modMultiply(t, this.w1);
        x2 = this.modMultiply(x2, this.w2);
        t = this.modAdd(t, x0);
        x1 = this.modAdd(t, x2);
        x2 = this.modSubtract(t, x2);
        if (this.op == TRANSFORM_COLUMNS) {
            x1 = this.modMultiply(x1, tmp1);
            x2 = this.modMultiply(x2, tmp2);
        }
        this.data[this.offset + i] = x0;
        this.data[this.offset + this.columns + i] = x1;
        this.data[this.offset + 2 * this.columns + i] = x2;
    }
}

