Set eos_token_id to None to generate until max length (#16989)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -413,7 +413,10 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
# Generate until max length
|
# Generate until max length
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||||
|
enc_dec_model.config.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
|
|||||||
@@ -314,6 +314,12 @@ class TFEncoderDecoderMixin:
|
|||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
# Generate until max length
|
||||||
|
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||||
|
enc_dec_model.config.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||||
|
|||||||
@@ -347,7 +347,8 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
# make sure EOS token is set to None to prevent early stopping of generation
|
# make sure EOS token is set to None to prevent early stopping of generation
|
||||||
enc_dec_model.config.eos_token_id = None
|
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||||
|
enc_dec_model.config.eos_token_id = None
|
||||||
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
enc_dec_model.config.decoder.eos_token_id = None
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
|
||||||
|
|||||||
@@ -300,6 +300,12 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
# Generate until max length
|
||||||
|
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||||
|
enc_dec_model.config.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
pixel_values, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
pixel_values, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||||
|
|||||||
@@ -269,6 +269,12 @@ class EncoderDecoderMixin:
|
|||||||
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
|
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
# Generate until max length
|
||||||
|
if hasattr(enc_dec_model.config, "eos_token_id"):
|
||||||
|
enc_dec_model.config.eos_token_id = None
|
||||||
|
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
|
||||||
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
inputs = pixel_values
|
inputs = pixel_values
|
||||||
|
|||||||
Reference in New Issue
Block a user