/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.graph;

import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Configuration-side graph vertex that wraps a single {@link InputPreProcessor}.
 *
 * <p>The vertex has no parameters of its own ({@link #numParams(boolean)} always
 * returns 0) and expects exactly one input. Its output type is either the
 * explicitly configured {@code outputType} or, when none was supplied, a type
 * derived from the kind of pre-processor and the input type
 * (see {@link #getOutputType(InputType...)}).
 *
 * <p>Equality and hashing are based on the wrapped pre-processor only; the
 * optional {@code outputType} does not take part in either.
 */
public class PreprocessorVertex
extends GraphVertex {
    /** The wrapped pre-processor; may be null when built via the no-arg constructor. */
    private InputPreProcessor preProcessor;
    /** Optional explicit output type; when non-null it overrides type inference. */
    private InputType outputType;

    /**
     * Creates a vertex whose output type is inferred from the pre-processor.
     *
     * @param preProcessor the pre-processor to wrap
     */
    public PreprocessorVertex(InputPreProcessor preProcessor) {
        this(preProcessor, null);
    }

    /**
     * Creates a vertex with an explicit output type.
     *
     * @param preProcessor the pre-processor to wrap
     * @param outputType   the output type to report, or {@code null} to have it inferred
     */
    public PreprocessorVertex(InputPreProcessor preProcessor, InputType outputType) {
        this.preProcessor = preProcessor;
        this.outputType = outputType;
    }

    /**
     * Returns a copy of this vertex wrapping a clone of the pre-processor.
     *
     * <p>The explicit output type (if any) is carried over so that the copy
     * reports the same output type as the original. A null pre-processor is
     * copied as null rather than failing.
     *
     * @return a new {@code PreprocessorVertex}
     */
    @Override
    public GraphVertex clone() {
        InputPreProcessor copiedPreProcessor = this.preProcessor == null ? null : this.preProcessor.clone();
        return new PreprocessorVertex(copiedPreProcessor, this.outputType);
    }

    /**
     * Two vertices are equal when their pre-processors are equal (or both null).
     *
     * @param o the object to compare against
     * @return {@code true} if {@code o} is a {@code PreprocessorVertex} with an equal pre-processor
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof PreprocessorVertex)) {
            return false;
        }
        InputPreProcessor otherPreProcessor = ((PreprocessorVertex)o).preProcessor;
        if (otherPreProcessor == null) {
            return this.preProcessor == null;
        }
        return otherPreProcessor.equals(this.preProcessor);
    }

    /**
     * Hash code derived from the pre-processor only, consistent with {@link #equals(Object)}.
     *
     * @return the pre-processor's hash code, or 0 when no pre-processor is set
     */
    @Override
    public int hashCode() {
        return this.preProcessor == null ? 0 : this.preProcessor.hashCode();
    }

    /**
     * @param backprop unused
     * @return always 0
     */
    @Override
    public int numParams(boolean backprop) {
        return 0;
    }

    /**
     * Creates the runtime vertex for this configuration.
     *
     * @param graph            the graph the runtime vertex belongs to
     * @param name             the vertex name
     * @param idx              the vertex index
     * @param paramsView       unused by this vertex
     * @param initializeParams unused by this vertex
     * @return a runtime pre-processor vertex wrapping the same pre-processor instance
     */
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams) {
        return new org.deeplearning4j.nn.graph.vertex.impl.PreprocessorVertex(graph, name, idx, this.preProcessor);
    }

    /**
     * Determines the output type of this vertex for the given input.
     *
     * <p>If an explicit output type was configured it is returned as-is.
     * Otherwise the result depends on the input type and the concrete
     * pre-processor class; a pre-processor that is not one of the recognised
     * conversion classes leaves the input kind unchanged.
     *
     * @param vertexInputs the input types; exactly one is required
     * @return the output type of this vertex
     * @throws InvalidInputTypeException if {@code vertexInputs} is null or does not
     *                                   contain exactly one element
     * @throws RuntimeException          if the input type is not FF, RNN or CNN
     */
    @Override
    public InputType getOutputType(InputType ... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs == null || vertexInputs.length != 1) {
            throw new InvalidInputTypeException("Invalid input: Preprocessor vertex expects exactly one input");
        }
        if (this.outputType != null) {
            return this.outputType;
        }
        InputType input = vertexInputs[0];
        switch (input.getType()) {
            case FF: {
                if (this.preProcessor instanceof FeedForwardToCnnPreProcessor) {
                    FeedForwardToCnnPreProcessor ffToCnn = (FeedForwardToCnnPreProcessor)this.preProcessor;
                    return InputType.convolutional(ffToCnn.getNumChannels(), ffToCnn.getInputWidth(), ffToCnn.getInputHeight());
                }
                int size = ((InputType.InputTypeFeedForward)input).getSize();
                if (this.preProcessor instanceof FeedForwardToRnnPreProcessor) {
                    return InputType.recurrent(size);
                }
                // Unrecognised pre-processor: stays feed-forward with the same size.
                return InputType.feedForward(size);
            }
            case RNN: {
                if (this.preProcessor instanceof RnnToCnnPreProcessor) {
                    RnnToCnnPreProcessor rnnToCnn = (RnnToCnnPreProcessor)this.preProcessor;
                    return InputType.convolutional(rnnToCnn.getNumChannels(), rnnToCnn.getInputWidth(), rnnToCnn.getInputHeight());
                }
                int size = ((InputType.InputTypeRecurrent)input).getSize();
                if (this.preProcessor instanceof RnnToFeedForwardPreProcessor) {
                    return InputType.feedForward(size);
                }
                // Unrecognised pre-processor: stays recurrent with the same size.
                return InputType.recurrent(size);
            }
            case CNN: {
                if (this.preProcessor instanceof CnnToFeedForwardPreProcessor) {
                    CnnToFeedForwardPreProcessor cnnToFf = (CnnToFeedForwardPreProcessor)this.preProcessor;
                    // Flattened size: height * width * channels.
                    int outSize = cnnToFf.getInputHeight() * cnnToFf.getInputWidth() * cnnToFf.getNumChannels();
                    return InputType.feedForward(outSize);
                }
                if (this.preProcessor instanceof CnnToRnnPreProcessor) {
                    CnnToRnnPreProcessor cnnToRnn = (CnnToRnnPreProcessor)this.preProcessor;
                    // Flattened size: height * width * channels.
                    int outSize = cnnToRnn.getInputHeight() * cnnToRnn.getInputWidth() * cnnToRnn.getNumChannels();
                    return InputType.recurrent(outSize);
                }
                // Unrecognised pre-processor: the input type is passed through unchanged.
                return input;
            }
        }
        throw new RuntimeException("Unknown InputType: " + input);
    }

    /** Creates an empty vertex; the pre-processor and output type are left null. */
    public PreprocessorVertex() {
    }

    /**
     * @return the wrapped pre-processor, possibly null
     */
    public InputPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /**
     * @param preProcessor the pre-processor to wrap
     */
    public void setPreProcessor(InputPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * @param outputType the explicit output type, or {@code null} to have it inferred
     */
    public void setOutputType(InputType outputType) {
        this.outputType = outputType;
    }

    /**
     * Describes this vertex by its pre-processor and configured output type.
     *
     * <p>The configured {@code outputType} field is printed directly (it is
     * {@code null} when the type is inferred). Calling
     * {@link #getOutputType(InputType...)} with no inputs here would always be
     * rejected by its exactly-one-input check.
     *
     * @return a string of the form {@code PreprocessorVertex(preProcessor=..., outputType=...)}
     */
    @Override
    public String toString() {
        return "PreprocessorVertex(preProcessor=" + this.getPreProcessor() + ", outputType=" + this.outputType + ")";
    }
}

