Transformer() takes 2D or 3D `src` and `tgt` tensors of one or more elements and returns the 2D or 3D tensor of one or more elements computed by the Transformer model, as shown below: import torch from torch import nn tensor1 = torch.tensor([[8., -3., 0., 1.]]) tensor2 = torch.tensor([[5., 9., -4., 8.], [-2., 7., 3., 6.]]) tensor1.requires_grad tensor2.requires_grad # False torch.manual_seed(42) tran1 = nn.Transformer(d_model=4, nhead=2) tensor3 = tran1(src=tensor1, tgt=tensor2) tensor3 # tensor([[1.5608, 0.1450, -0.6434, -1.0624], # [0.8815, 1.0994, -1.1523, -0.8286]], # grad_fn=<NativeLayerNormBackward0>) tensor3.requires_grad # True tran1 # Transformer( # (encoder): TransformerEncoder( # (layers): ModuleList( # (0-5): 6 x TransformerEncoderLayer( # (self_attn): MultiheadAttention( #…