From 0b568291d722aa5e39cd4a8fa05c03200dd280ab Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 16 Aug 2023 11:49:29 +0100 Subject: [PATCH] Marian: post-hack-fix correction (#25459) --- tests/models/marian/test_modeling_marian.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 8fd5e04a56..0ae0876e50 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -438,7 +438,11 @@ class MarianIntegrationTest(unittest.TestCase): ) self.assertEqual(self.model.device, model_inputs.input_ids.device) generated_ids = self.model.generate( - model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + num_beams=2, + max_length=128, + renormalize_logits=True, # Marian should always renormalize its logits. See #25459 ) generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) return generated_words