diff --git a/examples/summarization/bertabs/modeling_bertabs.py b/examples/summarization/bertabs/modeling_bertabs.py index 0691403186..117180d69b 100644 --- a/examples/summarization/bertabs/modeling_bertabs.py +++ b/examples/summarization/bertabs/modeling_bertabs.py @@ -844,7 +844,7 @@ class Translator(object): dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step) # Generator forward. - log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0)) + log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0)) vocab_size = log_probs.size(-1) if step < min_length: diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 408ca0f31d..73d2d22898 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -223,9 +223,7 @@ class EncoderLayer(nn.Module): encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x - x, attn_weights = self.self_attn( - query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions, - ) + x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.self_attn_layer_norm(x) @@ -378,7 +376,7 @@ class DecoderLayer(nn.Module): layer_state = {} # next line mutates layer state x, self_attn_weights = self.self_attn( - query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask, + query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -393,7 +391,6 @@ class DecoderLayer(nn.Module): key_padding_mask=encoder_attn_mask, layer_state=layer_state, # mutates layer state static_kv=True, - need_weights=False, # not returning it so why compute it ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -548,16 +545,12 @@ class SelfAttention(nn.Module): self, embed_dim, num_heads, - kdim=None, - vdim=None, dropout=0.0, bias=True, encoder_decoder_attention=False, # otherwise self_attention ): super().__init__() self.embed_dim = embed_dim - self.kdim = kdim if kdim is not None else embed_dim - self.vdim = vdim if vdim is not None else embed_dim self.num_heads = num_heads self.dropout = dropout @@ -566,13 +559,8 @@ class SelfAttention(nn.Module): self.scaling = self.head_dim ** -0.5 self.encoder_decoder_attention = encoder_decoder_attention - qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim # True for all BART - - assert self.encoder_decoder_attention or qkv_same_dim, ( - "Self-attention requires query, key and " "value to be of the same size" - ) - self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) - self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" @@ -587,7 +575,6 @@ class SelfAttention(nn.Module): value: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - need_weights: bool = False, static_kv: bool = False, attn_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -598,8 +585,6 @@ class SelfAttention(nn.Module): key_padding_mask (ByteTensor, optional): mask to exclude keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. - need_weights (bool, optional): return the attention weights, - averaged over heads (default: False). attn_mask (ByteTensor, optional): typically used to implement causal attention, where the mask prevents the attention from looking forward in time (default: None). diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 1ea46a091b..4ac9442a53 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -141,13 +141,13 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): _check_var(model.encoder.layers[0].fc1) _check_var(model.encoder.embed_positions) - decoder_features_with_created_mask = model.forward(**inputs_dict)[0] - decoder_features_with_passed_mask = model.forward( + decoder_features_with_created_mask = model(**inputs_dict)[0] + decoder_features_with_passed_mask = model( decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict )[0] _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask) useless_mask = torch.zeros_like(decoder_attn_mask) - decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0] + decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0] self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions self.assertEqual( decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model) @@ -156,7 +156,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item()) # Test different encoder attention masks - decoder_features_with_long_encoder_mask = model.forward( + decoder_features_with_long_encoder_mask = model( inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long() )[0] _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask) @@ -237,7 +237,7 @@ class BartHeadTests(unittest.TestCase): decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device) lm_model = BartForConditionalGeneration(config) lm_model.to(torch_device) - loss, logits, enc_features = lm_model.forward( + loss, logits, enc_features = lm_model( input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids ) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) @@ -259,7 +259,7 @@ class BartHeadTests(unittest.TestCase): lm_model = BartForConditionalGeneration(config).to(torch_device) context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) - loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary) + loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary) expected_shape = (*summary.shape, config.vocab_size) self.assertEqual(logits.shape, expected_shape) @@ -388,7 +388,7 @@ class BartModelIntegrationTest(unittest.TestCase): input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) with torch.no_grad(): - output = model.forward(**inputs_dict)[0] + output = model(**inputs_dict)[0] expected_shape = torch.Size((1, 11, 1024)) self.assertEqual(output.shape, expected_shape) expected_slice = torch.tensor( @@ -408,7 +408,7 @@ class BartModelIntegrationTest(unittest.TestCase): inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) # Test that model hasn't changed with torch.no_grad(): - batched_logits, features = model.forward(**inputs_dict) + batched_logits, features = model(**inputs_dict) expected_shape = torch.Size((2, 3)) self.assertEqual(batched_logits.shape, expected_shape) expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device) @@ -419,7 +419,7 @@ class BartModelIntegrationTest(unittest.TestCase): inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad) with torch.no_grad(): - logits2 = model.forward(**inputs_dict)[0] + logits2 = model(**inputs_dict)[0] _assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE) _assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)