diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index a52196ce99..a851d649be 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -640,9 +640,8 @@ class SelfAttention(nn.Module): reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool) attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - attn_weights_float = F.softmax(attn_weights, dim=-1) - attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,) - attn_weights = attn_weights_float.type_as(attn_weights) + attn_weights = F.softmax(attn_weights, dim=-1) + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,) assert v is not None attn_output = torch.bmm(attn_probs, v) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index ccb1946080..f588d445b2 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase): decoder_ffn_dim=32, max_position_embeddings=48, ) - lm_model = BartForMaskedLM(config).to(torch_device) - context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]) - summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]) + lm_model = BartForMaskedLM(config) + context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long() + summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long() logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary) expected_shape = (*summary.shape, config.vocab_size) self.assertEqual(logits.shape, expected_shape) def test_generate_beam_search(self): - input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]]) + input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long() config = BartConfig( vocab_size=self.vocab_size, d_model=24, @@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase): max_position_embeddings=48, output_past=True, ) - lm_model = BartForMaskedLM(config).to(torch_device) + lm_model = BartForMaskedLM(config) lm_model.eval() new_input_ids = lm_model.generate(