class TorchText::NN::ScaledDotProduct
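Scaled dot-product attention module: forward computes softmax(query · keyᵀ / sqrt(E)), optionally masked and passed through dropout, and multiplies the result with value. Inputs use the (L, N*H, E) layout (sequence length, batch size times attention heads, head dimension), or (N*H, L, E) when batch_first is true, as inferred from the size checks in forward below.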
Public Class Methods
new(dropout: 0.0, batch_first: false)
Calls superclass method
# File lib/torchtext/nn/scaled_dot_product.rb, line 4
def initialize(dropout: 0.0, batch_first: false)
  super()
  @dropout = dropout
  @batch_first = batch_first
end
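A minimal construction sketch (assumes the gem is loaded with require "torchtext"; the argument values are illustrative):

require "torchtext"

# dropout is applied to the attention weights in forward;
# batch_first: true selects the (N*H, L, E) input layout
sdp = TorchText::NN::ScaledDotProduct.new(dropout: 0.1, batch_first: false)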
Public Instance Methods
forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
# File lib/torchtext/nn/scaled_dot_product.rb, line 10
def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
  if @batch_first
    query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
  end

  if !bias_k.nil? && !bias_v.nil?
    unless key.size(-1) == bias_k.size(-1) && key.size(-2) == bias_k.size(-2) && bias_k.size(-3) == 1
      raise "Shape of bias_k is not supported"
    end
    unless value.size(-1) == bias_v.size(-1) && value.size(-2) == bias_v.size(-2) && bias_v.size(-3) == 1
      raise "Shape of bias_v is not supported"
    end

    # Append the bias vectors as one extra source position and widen the mask to match
    key = Torch.cat([key, bias_k])
    value = Torch.cat([value, bias_v])
    if !attn_mask.nil?
      attn_mask = Torch::NN::Functional.pad(attn_mask, [0, 1])
    end
  end

  tgt_len, head_dim = query.size(-3), query.size(-1)
  unless query.size(-1) == key.size(-1) && key.size(-1) == value.size(-1)
    raise "The feature dim of query, key, value must be equal."
  end
  unless key.size == value.size
    raise "Shape of key, value must match"
  end

  src_len = key.size(-3)
  batch_heads = [query.size(-2), key.size(-2)].max

  # Move the batch-heads dim in front of the sequence dim and scale query by 1/sqrt(head_dim)
  query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
  query = query * (head_dim.to_f ** -0.5)

  if !attn_mask.nil?
    if attn_mask.dim != 3
      raise RuntimeError, "attn_mask must be a 3D tensor."
    end
    if (attn_mask.size(-1) != src_len) || (attn_mask.size(-2) != tgt_len) ||
       (attn_mask.size(-3) != 1 && attn_mask.size(-3) != batch_heads)
      raise RuntimeError, "The size of the attn_mask is not correct."
    end
    if attn_mask.dtype != :bool
      raise RuntimeError, "Only bool tensor is supported for attn_mask"
    end
  end

  # Dot product of q, k
  attn_output_weights = Torch.matmul(query, key.transpose(-2, -1))
  if !attn_mask.nil?
    # masked_fill! takes (mask, value); masked positions get a large
    # negative value so they contribute ~0 after the softmax
    attn_output_weights.masked_fill!(attn_mask, -1e8)
  end
  attn_output_weights = Torch::NN::Functional.softmax(attn_output_weights, dim: -1)
  attn_output_weights = Torch::NN::Functional.dropout(attn_output_weights, p: @dropout, training: @training)
  attn_output = Torch.matmul(attn_output_weights, value)

  if @batch_first
    [attn_output, attn_output_weights]
  else
    [attn_output.transpose(-3, -2), attn_output_weights]
  end
end
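A hedged usage sketch for forward; the tensor shapes are inferred from the size checks above, and the concrete dimension values are illustrative, not part of the API:

require "torchtext"

sdp = TorchText::NN::ScaledDotProduct.new(dropout: 0.1)

# Default layout (batch_first: false): (L, N*H, E), where L = sequence
# length, N = batch size, H = attention heads, E = head dimension
l, n, h, e = 5, 2, 4, 16
query = Torch.randn(l, n * h, e)
key = Torch.randn(l, n * h, e)
value = Torch.randn(l, n * h, e)

# Optional bool mask of shape (1, L, L) or (N*H, L, L); true positions
# are filled with -1e8 before the softmax
attn_mask = Torch.zeros(1, l, l, dtype: :bool)

attn_output, attn_weights = sdp.forward(query, key, value, attn_mask: attn_mask)
attn_output.shape  # => [5, 8, 16]
attn_weights.shape # => [8, 5, 5]

When bias_k and bias_v are passed (each of shape (1, N*H, E), per the checks above), forward appends them to key and value as one extra source position and pads attn_mask by one column to match.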