enable cache by default (#9296)

This commit is contained in:
Suraj Patil
2020-12-24 17:47:36 +05:30
committed by GitHub
parent 6189ae9960
commit 2a18b70998
2 changed files with 10 additions and 0 deletions

View File

@@ -61,6 +61,9 @@ class BertGenerationConfig(PretrainedConfig):
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__. <https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
Examples:: Examples::
@@ -95,6 +98,7 @@ class BertGenerationConfig(PretrainedConfig):
eos_token_id=1, eos_token_id=1,
gradient_checkpointing=False, gradient_checkpointing=False,
position_embedding_type="absolute", position_embedding_type="absolute",
use_cache=True,
**kwargs **kwargs
): ):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -112,3 +116,4 @@ class BertGenerationConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type self.position_embedding_type = position_embedding_type
self.use_cache = use_cache

View File

@@ -339,6 +339,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None: