[BART] Remove unused kwargs (#3279)
* Remove unused kwargs
* Don't call forward in tests
This commit is contained in:
@@ -141,13 +141,13 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
_check_var(model.encoder.layers[0].fc1)
|
||||
_check_var(model.encoder.embed_positions)
|
||||
|
||||
decoder_features_with_created_mask = model.forward(**inputs_dict)[0]
|
||||
decoder_features_with_passed_mask = model.forward(
|
||||
decoder_features_with_created_mask = model(**inputs_dict)[0]
|
||||
decoder_features_with_passed_mask = model(
|
||||
decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
|
||||
self.assertEqual(
|
||||
decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
|
||||
@@ -156,7 +156,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
|
||||
|
||||
# Test different encoder attention masks
|
||||
decoder_features_with_long_encoder_mask = model.forward(
|
||||
decoder_features_with_long_encoder_mask = model(
|
||||
inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
@@ -237,7 +237,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_model = BartForConditionalGeneration(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(
|
||||
loss, logits, enc_features = lm_model(
|
||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||
)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
@@ -259,7 +259,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
|
||||
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
@@ -388,7 +388,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||
with torch.no_grad():
|
||||
output = model.forward(**inputs_dict)[0]
|
||||
output = model(**inputs_dict)[0]
|
||||
expected_shape = torch.Size((1, 11, 1024))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
@@ -408,7 +408,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||
# Test that model hasn't changed
|
||||
with torch.no_grad():
|
||||
batched_logits, features = model.forward(**inputs_dict)
|
||||
batched_logits, features = model(**inputs_dict)
|
||||
expected_shape = torch.Size((2, 3))
|
||||
self.assertEqual(batched_logits.shape, expected_shape)
|
||||
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)
|
||||
@@ -419,7 +419,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
|
||||
with torch.no_grad():
|
||||
logits2 = model.forward(**inputs_dict)[0]
|
||||
logits2 = model(**inputs_dict)[0]
|
||||
_assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
|
||||
_assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user