From 2a18b709989d3cbb27ca7c5162ec2e89b067324f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 24 Dec 2020 17:47:36 +0530 Subject: [PATCH] enable cache by default (#9296) --- .../models/bert_generation/configuration_bert_generation.py | 5 +++++ .../models/bert_generation/modeling_bert_generation.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py index b3e915c3f6..54659f4394 100644 --- a/src/transformers/models/bert_generation/configuration_bert_generation.py +++ b/src/transformers/models/bert_generation/configuration_bert_generation.py @@ -61,6 +61,9 @@ class BertGenerationConfig(PretrainedConfig): `__. For more information on :obj:`"relative_key_query"`, please refer to `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) `__. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. Examples:: @@ -95,6 +98,7 @@ class BertGenerationConfig(PretrainedConfig): eos_token_id=1, gradient_checkpointing=False, position_embedding_type="absolute", + use_cache=True, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -112,3 +116,4 @@ class BertGenerationConfig(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type + self.use_cache = use_cache diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 37d4a5dd37..35e6d15f93 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -339,6 +339,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: