nn_transformer_encoder {torch}    R Documentation
Transformer Encoder Module (R torch)
Description
Implements a stack of transformer encoder layers, optionally with a final layer normalization.
Usage
nn_transformer_encoder(encoder_layer, num_layers, norm = NULL)
Arguments
encoder_layer
(nn_module) an instance of nn_transformer_encoder_layer() defining the encoder layer to replicate.
num_layers
(integer) the number of encoder layers to stack.
norm
(nn_module or NULL) optional layer normalization module to apply after the last layer (e.g., nn_layer_norm).
Details
This module replicates the given encoder_layer num_layers times to construct the Transformer encoder. If a norm module is provided, it is applied to the output of the final encoder layer. The forward pass applies each encoder layer to the input in sequence.
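A minimal sketch of the norm argument (assuming torch is installed; nn_layer_norm(32) is chosen here to match the layer's d_model):

layer <- nn_transformer_encoder_layer(d_model = 32, nhead = 4, batch_first = TRUE)
norm  <- nn_layer_norm(32)                          # normalizes over the feature (d_model) dimension
enc   <- nn_transformer_encoder(layer, num_layers = 2, norm = norm)
y <- enc(torch_randn(8, 5, 32))                     # norm is applied to the last layer's output; shape stays (8, 5, 32)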
Value
An nn_module of class nn_transformer_encoder. Calling it on an input tensor of shape (S, N, E) or (N, S, E) (depending on batch_first) returns the encoded output of the same shape.
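As a sketch of the sequence-first layout (assuming the layer's default is batch_first = FALSE, mirroring PyTorch):

layer_sf <- nn_transformer_encoder_layer(d_model = 32, nhead = 4)  # batch_first left at its (assumed) default FALSE
enc_sf   <- nn_transformer_encoder(layer_sf, num_layers = 2)
x <- torch_randn(5, 8, 32)                                         # (S, N, E) = (seq_len, batch, feature)
y <- enc_sf(x)                                                     # output keeps the same shape, (5, 8, 32)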
Examples
if (torch_is_installed()) {
layer <- nn_transformer_encoder_layer(d_model = 32, nhead = 4, batch_first = TRUE)
model <- nn_transformer_encoder(layer, num_layers = 2)
x <- torch_randn(8, 5, 32) # (batch, seq, feature) since batch_first=TRUE
y <- model(x) # output shape is (8, 5, 32)
}