/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rbm;

import org.deeplearning4j.rbm.RBM;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Convolutional variant of {@link RBM}: filter responses are obtained by 2-D convolution of the
 * input with each row of {@code getW()}, then combined by block-wise pooling.
 *
 * <p>NOTE(review): {@code numFilters}, {@code poolRows} and {@code poolColumns} are never assigned
 * in this class; confirm how they are initialised (e.g. serialization or reflection).
 */
public class ConvolutionalRBM extends RBM {
    private static final long serialVersionUID = 6868729665328916878L;
    /** Number of convolution filters; row {@code k} of {@code getW()} is used as filter {@code k}. */
    private int numFilters;
    /** Rows per pooling block, measured on the transposed matrix that {@link #pool} sums over. */
    private int poolRows;
    /** Columns per pooling block, measured on the transposed matrix that {@link #pool} sums over. */
    private int poolColumns;

    /**
     * Computes the visible expectation for {@code visible}.
     *
     * <p>Each filter response is offset by {@code getvBias() + bias}, the responses are pooled
     * block-wise, and the result is mapped through {@code sigmoid(1 / (1 + blockSum))}.
     *
     * @param visible input matrix convolved with every filter
     * @param bias scalar added to the visible bias before it is applied
     * @return the pooled and squashed filter responses
     */
    public DoubleMatrix visibleExpectation(DoubleMatrix visible, double bias) {
        DoubleMatrix responses = this.filterResponses(visible, this.getvBias(), bias);
        // NOTE(review): sigmoid is applied on top of 1 / (1 + blockSum); confirm this is the
        // intended visible expectation.
        return MatrixUtil.sigmoid(this.inversePooled(responses));
    }

    /**
     * Computes {@code 1 / (1 + blockSum)} of the pooled, hidden-bias-offset filter responses.
     *
     * <p>This appears to match the "pooling unit off" probability of probabilistic max-pooling;
     * NOTE(review): confirm against the intended model.
     *
     * @param visible input matrix convolved with every filter
     * @param bias scalar added to the hidden bias before it is applied
     * @return element-wise {@code 1 / (1 + blockSum)} of the pooled responses
     */
    public DoubleMatrix pooledExpectation(DoubleMatrix visible, double bias) {
        return this.inversePooled(this.filterResponses(visible, this.gethBias(), bias));
    }

    /**
     * Exponentiates {@code hidden} and replaces every element with the sum of the exponentials in
     * its pooling block.
     *
     * <p>The input is transposed first, so blocks are {@code poolRows x poolColumns} tiles of
     * {@code exp(hidden^T)}. Tiles at the right/bottom edge are truncated to fit. The result is
     * transposed back, so it has the same shape as {@code hidden}.
     *
     * @param hidden matrix to pool
     * @return matrix shaped like {@code hidden} holding the block sums of {@code exp(hidden)}
     * @throws IllegalStateException if {@code hidden} is non-empty and a pool size is not positive
     */
    public DoubleMatrix pool(DoubleMatrix hidden) {
        DoubleMatrix active = MatrixFunctions.exp(hidden.transpose());
        DoubleMatrix pooled = DoubleMatrix.zeros(active.rows, active.columns);
        if (active.length == 0) {
            return pooled.transpose();
        }
        if (this.poolRows <= 0 || this.poolColumns <= 0) {
            throw new IllegalStateException("pool dimensions must be positive: poolRows="
                    + this.poolRows + ", poolColumns=" + this.poolColumns);
        }
        int colStart = 0;
        while (colStart < active.columns) {
            // Clamp via the remaining width so the addition cannot overflow.
            int colEnd = colStart + Math.min(this.poolColumns, active.columns - colStart);
            int rowStart = 0;
            while (rowStart < active.rows) {
                int rowEnd = rowStart + Math.min(this.poolRows, active.rows - rowStart);
                double blockSum = sumBlock(active, rowStart, rowEnd, colStart, colEnd);
                fillBlock(pooled, rowStart, rowEnd, colStart, colEnd, blockSum);
                rowStart = rowEnd;
            }
            colStart = colEnd;
        }
        return pooled.transpose();
    }

    /**
     * Convolves {@code visible} with every filter and stacks the offset responses as rows.
     *
     * @param visible input matrix
     * @param biasVector bias matrix added to each convolution result
     * @param bias scalar added to {@code biasVector} first
     * @return a {@code numFilters x n} matrix (n = elements per response), or an empty
     *     {@code 0 x 1} matrix when there are no filters
     */
    private DoubleMatrix filterResponses(DoubleMatrix visible, DoubleMatrix biasVector, double bias) {
        if (this.numFilters <= 0) {
            return new DoubleMatrix(0);
        }
        DoubleMatrix shift = biasVector.add(bias);
        DoubleMatrix weights = this.getW();
        DoubleMatrix responses = null;
        for (int k = 0; k < this.numFilters; ++k) {
            DoubleMatrix response = MatrixUtil.convolution2D(visible, visible.columns, visible.rows,
                    weights.getRow(k), weights.rows, weights.columns).add(shift).transpose();
            if (responses == null) {
                // Size the rows to the full response; a numFilters x 1 matrix would make
                // putRow silently keep only the first element.
                responses = new DoubleMatrix(this.numFilters, response.length);
            }
            responses.putRow(k, response);
        }
        return responses;
    }

    /** Pools {@code responses} and returns the element-wise {@code 1 / (1 + blockSum)}. */
    private DoubleMatrix inversePooled(DoubleMatrix responses) {
        DoubleMatrix pooled = this.pool(responses);
        pooled.addi(1.0);
        return MatrixUtil.oneDiv(pooled);
    }

    /** Sums {@code m} over rows {@code [r0, r1)} and columns {@code [c0, c1)}. */
    private static double sumBlock(DoubleMatrix m, int r0, int r1, int c0, int c1) {
        double total = 0.0;
        for (int c = c0; c < c1; ++c) {
            for (int r = r0; r < r1; ++r) {
                total += m.get(r, c);
            }
        }
        return total;
    }

    /** Writes {@code value} into every cell of {@code m} in rows {@code [r0, r1)}, columns {@code [c0, c1)}. */
    private static void fillBlock(DoubleMatrix m, int r0, int r1, int c0, int c1, double value) {
        for (int c = c0; c < c1; ++c) {
            for (int r = r0; r < r1; ++r) {
                m.put(r, c, value);
            }
        }
    }
}

