From 4e4403c9b44324671cb795df2ef30e70fe3b606e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 19 Mar 2020 11:56:54 -0400 Subject: [PATCH] [BART] torch 1.0 compatibility (#3322) * config.activation_function --- src/transformers/activations.py | 6 +-- src/transformers/configuration_bart.py | 2 + src/transformers/modeling_bart.py | 58 ++++++++++---------------- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index dd0a7aa377..7968b88ba9 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -44,8 +44,4 @@ def get_activation(activation_string): if activation_string in ACT2FN: return ACT2FN[activation_string] else: - raise KeyError( - "function {} not found in ACT2FN mapping {} or torch.nn.functional".format( - activation_string, list(ACT2FN.keys()) - ) - ) + raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index f6733a9bc4..3bb26ead68 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -39,6 +39,7 @@ class BartConfig(PretrainedConfig): def __init__( self, activation_dropout=0.0, + activation_function="gelu", vocab_size=50265, bos_token_id=0, pad_token_id=1, @@ -89,6 +90,7 @@ class BartConfig(PretrainedConfig): self.decoder_attention_heads = decoder_attention_heads self.max_position_embeddings = max_position_embeddings self.init_std = init_std # Normal(0, this parameter) + self.activation_function = activation_function # 3 Types of Dropout self.attention_dropout = attention_dropout diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index ee8e7c54cc..b1c3e03466 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -21,6 +21,7 @@ import torch import torch.nn.functional as F from torch import Tensor, nn +from .activations import ACT2FN from .configuration_bart import BartConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids @@ -196,7 +197,7 @@ class EncoderLayer(nn.Module): ) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout = config.dropout - self.activation_fn = F.gelu + self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) @@ -278,8 +279,8 @@ class BartEncoder(nn.Module): # check attention mask and invert if attention_mask is not None: assert attention_mask.dim() == 2 - attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE - assert attention_mask.max() <= 0 + attention_mask = attention_mask.eq(0) + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(input_ids) x = inputs_embeds + embed_pos @@ -318,7 +319,7 @@ class DecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout - self.activation_fn = F.gelu + self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = LayerNorm(self.embed_dim) @@ -334,13 +335,7 @@ class DecoderLayer(nn.Module): self.final_layer_norm = LayerNorm(self.embed_dim) def forward( - self, - x, - encoder_hidden_states, - encoder_attn_mask=None, - layer_state=None, - attention_mask=None, - need_attn_weights=False, + self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None, ): residual = x @@ -437,9 +432,7 @@ class BartDecoder(nn.Module): # check attention mask and invert if encoder_padding_mask is not None: assert encoder_padding_mask.dim() == 2 - - encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0 - assert encoder_padding_mask.max() <= 0 + encoder_padding_mask = encoder_padding_mask.eq(0) # embed positions positions = self.embed_positions(input_ids, generation_mode=generation_mode) @@ -469,12 +462,7 @@ class BartDecoder(nn.Module): layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None x, layer_self_attn, layer_past = decoder_layer( - x, - encoder_hidden_states, - encoder_padding_mask, - layer_state=layer_state, - attention_mask=combined_mask, - need_attn_weights=self.output_attentions, + x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask, ) if self.output_past: @@ -598,7 +586,7 @@ class SelfAttention(nn.Module): if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool) + reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2) attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) @@ -648,22 +636,20 @@ class SelfAttention(nn.Module): static_kv: bool, ) -> Optional[Tensor]: # saved key padding masks have shape (bsz, seq_len) - if prev_key_padding_mask is not None and static_kv: - new_key_padding_mask = prev_key_padding_mask - elif prev_key_padding_mask is not None and key_padding_mask is not None: - new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) - # During incremental decoding, as the padding token enters and - # leaves the frame, there will be a time when prev or current is None - elif prev_key_padding_mask is not None: - filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)) - if prev_key_padding_mask.is_cuda: - filler = filler.to(prev_key_padding_mask.device) - new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + if prev_key_padding_mask is not None: + if static_kv: + new_key_padding_mask = prev_key_padding_mask + else: + new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) + elif key_padding_mask is not None: - filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)) - if key_padding_mask.is_cuda: - filler = filler.cuda() - new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + filler = torch.zeros( + batch_size, + src_len - key_padding_mask.size(1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1) else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask