diff --git a/tests/encoder_decoder/test_modeling_encoder_decoder.py b/tests/encoder_decoder/test_modeling_encoder_decoder.py index 46a1bf7b68..8412ccb389 100644 --- a/tests/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_encoder_decoder.py @@ -413,7 +413,10 @@ class EncoderDecoderMixin: enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) # Generate until max length - enc_dec_model.config.decoder.eos_token_id = None + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.to(torch_device) # Bert does not have a bos token id, so use pad_token_id instead diff --git a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py index de903c40c2..bedd72fe24 100644 --- a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -314,6 +314,12 @@ class TFEncoderDecoderMixin: encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + # Generate until max length + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + enc_dec_model.config.decoder.eos_token_id = None + # Bert does not have a bos token id, so use pad_token_id instead generated_output = enc_dec_model.generate( input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id diff --git a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 4bc7c52943..c17792084d 100644 --- a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -347,7 +347,8 @@ class EncoderDecoderMixin: enc_dec_model.to(torch_device) # make sure EOS token is set to None to prevent early stopping of generation - enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): enc_dec_model.config.decoder.eos_token_id = None diff --git a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index f3a062744f..158aa4e5f0 100644 --- a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -300,6 +300,12 @@ class TFVisionEncoderDecoderMixin: encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + # Generate until max length + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + enc_dec_model.config.decoder.eos_token_id = None + # Bert does not have a bos token id, so use pad_token_id instead generated_output = enc_dec_model.generate( pixel_values, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id diff --git a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 90c61ab185..b867778ec9 100644 --- a/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -269,6 +269,12 @@ class EncoderDecoderMixin: def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + # Generate until max length + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.to(torch_device) inputs = pixel_values