Generate: missing generation config eos token setting in encoder-decoder tests (#29146)
This commit is contained in:
@@ -473,6 +473,8 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.config.eos_token_id = None
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||||
|
enc_dec_model.generation_config.eos_token_id = None
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
|
|||||||
@@ -377,6 +377,8 @@ class TFEncoderDecoderMixin:
|
|||||||
enc_dec_model.config.eos_token_id = None
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||||
|
enc_dec_model.generation_config.eos_token_id = None
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
|
|||||||
@@ -351,6 +351,8 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.config.eos_token_id = None
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||||
|
enc_dec_model.generation_config.eos_token_id = None
|
||||||
|
|
||||||
inputs = input_values if input_features is None else input_features
|
inputs = input_values if input_features is None else input_features
|
||||||
|
|
||||||
|
|||||||
@@ -308,6 +308,8 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
enc_dec_model.config.eos_token_id = None
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
|
||||||
|
enc_dec_model.generation_config.eos_token_id = None
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
|
|||||||
Reference in New Issue
Block a user