examples/seq2seq/run_eval.py fixes and docs (#5322)

This commit is contained in:
Sam Shleifer
2020-06-26 19:20:43 -04:00
committed by GitHub
parent 5543b30aa6
commit 393b8dc09a
5 changed files with 79 additions and 27 deletions

View File

@@ -253,9 +253,9 @@ class MBartIntegrationTests(unittest.TestCase):
with torch.no_grad():
logits, *other_stuff = model(**net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype)
result_slice = logits[0][0][:3]
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
expected_slice = [9.0078, 10.1113, 14.4787]
result_slice = logits[0][0][:3].tolist()
self.assertListEqual(expected_slice, result_slice)
@slow
def test_enro_generate(self):