From 5620033115e571013325d017bcca92991b0a4ace Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 11 Jun 2020 22:11:34 -0400 Subject: [PATCH] [mbart] Fix fp16 testing logic (#4949) --- tests/test_modeling_bart.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 8e5daa96fd..43e9f3f6dc 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -249,7 +249,7 @@ class MBartIntegrationTests(unittest.TestCase): with torch.no_grad(): logits, *other_stuff = model(**net_input) - expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device) + expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype) result_slice = logits[0][0][:3] self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE)) @@ -293,9 +293,6 @@ class MBartIntegrationTests(unittest.TestCase): expected_shape = (*summary.shape, config.vocab_size) self.assertEqual(logits.shape, expected_shape) - -@require_torch -class MBartTokenizerTests(MBartIntegrationTests): def test_enro_tokenizer_prepare_translation_batch(self): batch = self.tokenizer.prepare_translation_batch( self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),