Correct eos_token_id settings in generate (#15403)
* Correct eos_token_id set in generate * Set eos_token_id in test * Correct eos_token_id set in generate * Set eos_token_id in test
This commit is contained in:
@@ -1049,6 +1049,8 @@ class GenerationMixin:
|
||||
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
if eos_token_id is None and hasattr(self.config, "decoder"):
|
||||
eos_token_id = self.config.decoder.eos_token_id
|
||||
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
@@ -395,6 +395,9 @@ class EncoderDecoderMixin:
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
|
||||
# Generate until max length
|
||||
enc_dec_model.config.decoder.eos_token_id = None
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
|
||||
Reference in New Issue
Block a user