[XLNet] Fix mems behavior (#8567)
* fix mems in xlnet * fix use_mems * fix use_mem_len * fix use mems * clean docs * fix tf typo * make xlnet tf for generation work * fix tf test * refactor use cache * add use cache for missing models * correct use_cache in generate * correct use cache in tf generate * fix tf * correct getattr typo * make sylvain happy * change in docs as well * do not apply to cookie cutter statements * fix tf test * make pytorch model fully backward compatible
This commit is contained in:
committed by
GitHub
parent
369f1d77b4
commit
2a6fbe6a40
@@ -38,6 +38,7 @@ class TFGenerationMixin:
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
use_cache = getattr(self.config, "use_cache", False)
|
||||
if len(outputs) <= 1 or use_cache is False:
|
||||
return False
|
||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
||||
@@ -194,7 +195,6 @@ class TFGenerationMixin:
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
@@ -224,7 +224,6 @@ class TFGenerationMixin:
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictly positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
|
||||
Reference in New Issue
Block a user