Generate: missing generation config eos token setting in encoder-decoder tests (#29146)

This commit is contained in:
Joao Gante
2024-02-20 16:17:51 +00:00
committed by GitHub
parent 1c81132e80
commit 857fd8eaab
4 changed files with 8 additions and 0 deletions

View File

@@ -473,6 +473,8 @@ class EncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead

View File

@@ -377,6 +377,8 @@ class TFEncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(