Correct eos_token_id settings in generate (#15403)
* Correct eos_token_id set in generate * Set eos_token_id in test * Correct eos_token_id set in generate * Set eos_token_id in test
This commit is contained in:
@@ -1049,6 +1049,8 @@ class GenerationMixin:
|
|||||||
|
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
if eos_token_id is None and hasattr(self.config, "decoder"):
|
||||||
|
eos_token_id = self.config.decoder.eos_token_id
|
||||||
|
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
|||||||
@@ -395,6 +395,9 @@ class EncoderDecoderMixin:
|
|||||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
|
|
||||||
|
# Generate until max length
|
||||||
|
enc_dec_model.config.decoder.eos_token_id = None
|
||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
|
|||||||
Reference in New Issue
Block a user