From 7e662e6a3be0ece455b4c4ae2c3348beab11bad5 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 19 Jan 2021 17:11:22 -0500 Subject: [PATCH] Fix model templates and use less than 119 chars (#9684) * Fix model templates and use less than 119 chars * Missing new line --- src/transformers/models/bart/modeling_bart.py | 3 ++- src/transformers/models/bert/modeling_bert.py | 3 ++- src/transformers/models/blenderbot/modeling_blenderbot.py | 3 ++- .../models/blenderbot_small/modeling_blenderbot_small.py | 3 ++- src/transformers/models/electra/modeling_electra.py | 3 ++- src/transformers/models/gpt2/modeling_gpt2.py | 3 ++- src/transformers/models/layoutlm/modeling_layoutlm.py | 3 ++- src/transformers/models/led/modeling_led.py | 3 ++- src/transformers/models/marian/modeling_marian.py | 3 ++- src/transformers/models/mbart/modeling_mbart.py | 3 ++- src/transformers/models/pegasus/modeling_pegasus.py | 3 ++- src/transformers/models/roberta/modeling_roberta.py | 3 ++- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 8 ++++++++ 13 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 6a6fac4690..c8462276bf 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -997,7 +997,8 @@ class BartDecoder(BartPretrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 72c795e621..0181757074 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -544,7 +544,8 @@ class BertEncoder(nn.Module): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 2b962c5e8f..52e62863eb 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -959,7 +959,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 8003505d59..8d576b8b3a 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -959,7 +959,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 0374871a77..9de0be3c49 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -541,7 +541,8 @@ class ElectraEncoder(nn.Module): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 175f9b1c42..cebf705d71 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -721,7 +721,8 @@ class GPT2Model(GPT2PreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 31c46fd9f9..96046c07da 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -471,7 +471,8 @@ class LayoutLMEncoder(nn.Module): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index c79cc2a0e9..7e04e95de3 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1924,7 +1924,8 @@ class LEDDecoder(LEDPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 0ed8e1cfd1..39f3c87b00 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -962,7 +962,8 @@ class MarianDecoder(MarianPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9c5f245b6c..333cd1bccc 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1006,7 +1006,8 @@ class MBartDecoder(MBartPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ecae05aab5..f1a41d51d5 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -970,7 +970,8 @@ class PegasusDecoder(PegasusPreTrainedModel): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 3213ff488e..1760617044 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -484,7 +484,8 @@ class RobertaEncoder(nn.Module): if use_cache: logger.warn( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." ) use_cache = False diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 9148157b7d..0173e8e5bc 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -526,8 +526,16 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions)