set eos_token_id to None to generate until max length (#16989)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -413,7 +413,10 @@ class EncoderDecoderMixin:
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
|
||||
# Generate until max length
|
||||
enc_dec_model.config.decoder.eos_token_id = None
|
||||
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||
enc_dec_model.config.eos_token_id = None
|
||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||
enc_dec_model.config.decoder.eos_token_id = None
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
|
||||
Reference in New Issue
Block a user