Marian: post-hack-fix correction (#25459)

This commit is contained in:
Joao Gante
2023-08-16 11:49:29 +01:00
committed by GitHub
parent 5ccf343aeb
commit 0b568291d7

View File

@@ -438,7 +438,11 @@ class MarianIntegrationTest(unittest.TestCase):
) )
self.assertEqual(self.model.device, model_inputs.input_ids.device) self.assertEqual(self.model.device, model_inputs.input_ids.device)
generated_ids = self.model.generate( generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128 model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
num_beams=2,
max_length=128,
renormalize_logits=True, # Marian should always renormalize its logits. See #25459
) )
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return generated_words return generated_words