Fix slow tests v2 (#8746)

* Fix BART test

* Fix MBART tests

* Remove erroneous line from yaml

* Update tests/test_modeling_bart.py

* Quality
This commit is contained in:
Lysandre Debut
2020-11-24 09:35:12 -05:00
committed by GitHub
parent 2c83b3c38d
commit 6fdd0bb231
3 changed files with 4 additions and 3 deletions

View File

@@ -492,7 +492,9 @@ class BartModelIntegrationTests(unittest.TestCase):
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
# Test that model hasn't changed
with torch.no_grad():
batched_logits, features = model(**inputs_dict)
outputs = model(**inputs_dict)
batched_logits = outputs[0]
expected_shape = torch.Size((2, 3))
self.assertEqual(batched_logits.shape, expected_shape)
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)