Marian: post-hack-fix correction (#25459)
This commit is contained in:
@@ -438,7 +438,11 @@ class MarianIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
model_inputs.input_ids,
|
||||||
|
attention_mask=model_inputs.attention_mask,
|
||||||
|
num_beams=2,
|
||||||
|
max_length=128,
|
||||||
|
renormalize_logits=True, # Marian should always renormalize its logits. See #25459
|
||||||
)
|
)
|
||||||
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
return generated_words
|
return generated_words
|
||||||
|
|||||||
Reference in New Issue
Block a user