diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index e822609ecc..90bc5e19b7 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1049,6 +1049,8 @@ class GenerationMixin: pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if eos_token_id is None and hasattr(self.config, "decoder"): + eos_token_id = self.config.decoder.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 9c4ab74c72..2c9de822c6 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -395,6 +395,9 @@ class EncoderDecoderMixin: def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + # Generate until max length + enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.to(torch_device) # Bert does not have a bos token id, so use pad_token_id instead