Force use_cache to be False in PyTorch (#15385)

* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-02-08 16:20:53 +01:00
committed by GitHub
parent 0acd84f7cb
commit 6a5472a8e1
9 changed files with 27 additions and 0 deletions

View File

@@ -2832,6 +2832,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)