set eos_token_id to None to generate until max length (#16989)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-04-28 19:47:38 +02:00
committed by GitHub
parent 01562dac7e
commit 5af5735f62
5 changed files with 24 additions and 2 deletions

View File

@@ -347,7 +347,8 @@ class EncoderDecoderMixin:
enc_dec_model.to(torch_device)
# make sure EOS token is set to None to prevent early stopping of generation
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "eos_token_id"):
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None