[XLNet] Fix mems behavior (#8567)
* fix mems in xlnet * fix use_mems * fix use_mem_len * fix use mems * clean docs * fix tf typo * make xlnet tf for generation work * fix tf test * refactor use cache * add use cache for missing models * correct use_cache in generate * correct use cache in tf generate * fix tf * correct getattr typo * make sylvain happy * change in docs as well * do not apply to cookie cutter statements * fix tf test * make pytorch model fully backward compatible
This commit is contained in:
committed by
GitHub
parent
369f1d77b4
commit
2a6fbe6a40
@@ -72,6 +72,8 @@ RAG_CONFIG_DOC = r"""
|
||||
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and
|
||||
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
"""
|
||||
|
||||
|
||||
@@ -107,6 +109,7 @@ class RagConfig(PretrainedConfig):
|
||||
exclude_bos_score=False,
|
||||
do_marginalize=False,
|
||||
output_retrieved=False,
|
||||
use_cache=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -156,6 +159,8 @@ class RagConfig(PretrainedConfig):
|
||||
|
||||
self.do_deduplication = do_deduplication
|
||||
|
||||
self.use_cache = use_cache
|
||||
|
||||
@classmethod
|
||||
def from_question_encoder_generator_configs(
|
||||
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
|
||||
|
||||
Reference in New Issue
Block a user