[XLNet] Fix mems behavior (#8567)

* fix mems in xlnet

* fix use_mems

* fix use_mem_len

* fix use mems

* clean docs

* fix tf typo

* make xlnet tf for generation work

* fix tf test

* refactor use cache

* add use cache for missing models

* correct use_cache in generate

* correct use cache in tf generate

* fix tf

* correct getattr typo

* make sylvain happy

* change in docs as well

* do not apply to cookie cutter statements

* fix tf test

* make pytorch model fully backward compatible
This commit is contained in:
Patrick von Platen
2020-11-25 22:54:59 +01:00
committed by GitHub
parent 369f1d77b4
commit 2a6fbe6a40
47 changed files with 259 additions and 134 deletions

View File

@@ -72,6 +72,8 @@ RAG_CONFIG_DOC = r"""
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and
:obj:`context_attention_mask` are returned. See returned tensors for more detail.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
@@ -107,6 +109,7 @@ class RagConfig(PretrainedConfig):
exclude_bos_score=False,
do_marginalize=False,
output_retrieved=False,
use_cache=True,
**kwargs
):
super().__init__(
@@ -156,6 +159,8 @@ class RagConfig(PretrainedConfig):
self.do_deduplication = do_deduplication
self.use_cache = use_cache
@classmethod
def from_question_encoder_generator_configs(
cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs