check decoder_inputs_embeds is None before shifting labels (#19671)
This commit is contained in:
@@ -2938,7 +2938,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -2833,7 +2833,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
|
||||
Reference in New Issue
Block a user