enable cache by default (#9296)
This commit is contained in:
@@ -61,6 +61,9 @@ class BertGenerationConfig(PretrainedConfig):
|
|||||||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
||||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
||||||
<https://arxiv.org/abs/2009.13658>`__.
|
<https://arxiv.org/abs/2009.13658>`__.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if ``config.is_decoder=True``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -95,6 +98,7 @@ class BertGenerationConfig(PretrainedConfig):
|
|||||||
eos_token_id=1,
|
eos_token_id=1,
|
||||||
gradient_checkpointing=False,
|
gradient_checkpointing=False,
|
||||||
position_embedding_type="absolute",
|
position_embedding_type="absolute",
|
||||||
|
use_cache=True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
@@ -112,3 +116,4 @@ class BertGenerationConfig(PretrainedConfig):
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
|
self.use_cache = use_cache
|
||||||
|
|||||||
@@ -339,6 +339,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user