Force use_cache to be False in PyTorch (#15385)
* use_cache = False for PT models if labels is passed * Fix for BigBirdPegasusForConditionalGeneration * add warning if users specify use_cache=True * Use logger.warning instead of warnings.warn Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1318,6 +1318,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -2513,6 +2513,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -1287,6 +1287,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -1258,6 +1258,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -2366,6 +2366,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -1291,6 +1291,9 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -1314,6 +1314,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1381,6 +1381,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -2832,6 +2832,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
|
||||||
|
use_cache = False
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user