/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;

public class LSTM
extends RecurrentBlock {
    LSTM(Builder builder) {
        super(builder);
        this.gates = 4;
    }

    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        Shape stateShape;
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        Device device = inputs.head().getDevice();
        NDList rnnParams = new NDList();
        for (Parameter parameter : this.parameters.values()) {
            rnnParams.add(parameterStore.getValue(parameter, device, training));
        }
        NDArray input = inputs.head();
        if (inputs.size() == 1) {
            int batchIndex = this.batchFirst ? 0 : 1;
            stateShape = new Shape((long)this.numLayers * (long)this.getNumDirections(), input.size(batchIndex), this.stateSize);
            inputs.add(input.getManager().zeros(stateShape));
            inputs.add(input.getManager().zeros(stateShape));
        }
        if (inputs.size() == 2) {
            int batchIndex = this.batchFirst ? 0 : 1;
            stateShape = new Shape((long)this.numLayers * (long)this.getNumDirections(), input.size(batchIndex), this.stateSize);
            inputs.add(input.getManager().zeros(stateShape));
        }
        NDList outputs = ex.lstm(input, new NDList((NDArray)inputs.get(1), (NDArray)inputs.get(2)), rnnParams, this.hasBiases, this.numLayers, this.dropRate, training, this.bidirectional, this.batchFirst);
        if (this.returnState) {
            return outputs;
        }
        outputs.stream().skip(1L).forEach(NDArray::close);
        return new NDList((NDArray)outputs.get(0));
    }

    public static Builder builder() {
        return new Builder();
    }

    public static final class Builder
    extends RecurrentBlock.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        public LSTM build() {
            Preconditions.checkArgument(this.stateSize > 0L && this.numLayers > 0, "Must set stateSize and numStackedLayers");
            return new LSTM(this);
        }
    }
}

