[mbart] Fix fp16 testing logic (#4949)
This commit is contained in:
@@ -249,7 +249,7 @@ class MBartIntegrationTests(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits, *other_stuff = model(**net_input)
|
logits, *other_stuff = model(**net_input)
|
||||||
|
|
||||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
|
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device, dtype=model.dtype)
|
||||||
result_slice = logits[0][0][:3]
|
result_slice = logits[0][0][:3]
|
||||||
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(expected_slice, result_slice, atol=TOLERANCE))
|
||||||
|
|
||||||
@@ -293,9 +293,6 @@ class MBartIntegrationTests(unittest.TestCase):
|
|||||||
expected_shape = (*summary.shape, config.vocab_size)
|
expected_shape = (*summary.shape, config.vocab_size)
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
class MBartTokenizerTests(MBartIntegrationTests):
|
|
||||||
def test_enro_tokenizer_prepare_translation_batch(self):
|
def test_enro_tokenizer_prepare_translation_batch(self):
|
||||||
batch = self.tokenizer.prepare_translation_batch(
|
batch = self.tokenizer.prepare_translation_batch(
|
||||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
|
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
|
||||||
|
|||||||
Reference in New Issue
Block a user