Check that decoder_inputs_embeds is None before shifting labels to build decoder_input_ids (#19671)
This commit is contained in:
@@ -1352,7 +1352,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1319,7 +1319,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1371,7 +1371,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1286,7 +1286,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1351,7 +1351,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -2428,7 +2428,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -2445,7 +2445,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1432,7 +1432,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1388,7 +1388,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1347,7 +1347,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
|
||||
@@ -1387,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
|
||||
@@ -1393,7 +1393,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1397,7 +1397,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
||||
labels,
|
||||
)
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1605,7 +1605,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1314,7 +1314,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
|
||||
@@ -1341,7 +1341,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1405,7 +1405,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1293,7 +1293,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -1183,7 +1183,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if labels is not None:
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -2938,7 +2938,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
@@ -2833,7 +2833,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
if use_cache:
|
||||
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||
use_cache = False
|
||||
if decoder_input_ids is None:
|
||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
outputs = self.model(
|
||||
|
||||
Reference in New Issue
Block a user