Correct eos_token_id settings in generate (#15403)
* Correct eos_token_id set in generate * Set eos_token_id in test * Correct eos_token_id set in generate * Set eos_token_id in test
This commit is contained in:
@@ -395,6 +395,9 @@ class EncoderDecoderMixin:
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
|
||||
# Generate until max length
|
||||
enc_dec_model.config.decoder.eos_token_id = None
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
|
||||
Reference in New Issue
Block a user