Marian: post-hack-fix correction (#25459)
This commit is contained in:
@@ -438,7 +438,11 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
num_beams=2,
|
||||
max_length=128,
|
||||
renormalize_logits=True, # Marian should always renormalize its logits. See #25459
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return generated_words
|
||||
|
||||
Reference in New Issue
Block a user