Correct eos_token_id settings in generate (#15403)

* Correct eos_token_id set in generate

* Set eos_token_id in test

* Correct eos_token_id set in generate

* Set eos_token_id in test
This commit is contained in:
CHI LIU
2022-02-03 07:24:40 +08:00
committed by GitHub
parent 39b5d1a63a
commit 5ec368d79e
2 changed files with 5 additions and 0 deletions

View File

@@ -395,6 +395,9 @@ class EncoderDecoderMixin:
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
enc_dec_model.config.decoder.eos_token_id = None
enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead