Generate: Deprecate returning legacy cache by default; Handle use_cache=False (#32863)
This commit is contained in:
@@ -194,6 +194,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -207,6 +208,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -224,6 +226,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
@@ -239,6 +242,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -256,6 +260,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -268,6 +273,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -286,6 +292,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
@@ -299,6 +306,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -317,6 +325,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -329,6 +338,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -348,6 +358,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -361,6 +372,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
constraints=constraints,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -378,6 +390,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
contrastive_search_kwargs = {
|
||||
"penalty_alpha": 0.6,
|
||||
@@ -396,6 +409,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
@@ -419,7 +433,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
@@ -430,6 +443,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -454,7 +468,6 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
@@ -466,6 +479,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -495,7 +509,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
@@ -507,6 +520,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -545,9 +559,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
output_generate = self._beam_search_generate(
|
||||
@@ -560,6 +571,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -589,7 +601,6 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._beam_search_generate(
|
||||
@@ -602,6 +613,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -676,9 +688,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
@@ -692,6 +701,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -764,7 +774,6 @@ class GenerationTesterMixin:
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@@ -778,6 +787,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -857,9 +867,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Sample constraints
|
||||
@@ -882,6 +889,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -913,13 +921,12 @@ class GenerationTesterMixin:
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test old generation output for backwards compatibility
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._contrastive_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -940,7 +947,6 @@ class GenerationTesterMixin:
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -953,6 +959,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -978,7 +985,6 @@ class GenerationTesterMixin:
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test output equality of low versus high memory
|
||||
@@ -991,6 +997,7 @@ class GenerationTesterMixin:
|
||||
low_memory=True,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
@@ -1000,6 +1007,7 @@ class GenerationTesterMixin:
|
||||
low_memory=False,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@@ -1031,10 +1039,17 @@ class GenerationTesterMixin:
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
|
||||
low_output = model.generate(
|
||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
|
||||
input_ids,
|
||||
max_new_tokens=8,
|
||||
num_beams=5,
|
||||
early_stopping=True,
|
||||
low_memory=False,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@@ -1079,7 +1094,6 @@ class GenerationTesterMixin:
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# Sets assisted generation arguments such that:
|
||||
@@ -1098,6 +1112,7 @@ class GenerationTesterMixin:
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
@@ -1150,7 +1165,6 @@ class GenerationTesterMixin:
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# Sets assisted generation arguments such that:
|
||||
@@ -1169,6 +1183,7 @@ class GenerationTesterMixin:
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@@ -1196,12 +1211,6 @@ class GenerationTesterMixin:
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# Some models don't support the cache and returning past_key_values
|
||||
if not hasattr(config, "use_cache"):
|
||||
config.use_cache = False
|
||||
else:
|
||||
config.use_cache = True
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||
@@ -1224,11 +1233,12 @@ class GenerationTesterMixin:
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
|
||||
}
|
||||
generation_kwargs.update({"dola_layers": "low"})
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
|
||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
|
||||
|
||||
def test_assisted_decoding_sample(self):
|
||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||
@@ -1261,7 +1271,6 @@ class GenerationTesterMixin:
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# Sets assisted generation arguments such that:
|
||||
@@ -1284,6 +1293,7 @@ class GenerationTesterMixin:
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
@@ -1566,7 +1576,6 @@ class GenerationTesterMixin:
|
||||
# 3. ignore `token_type_ids` for simplicity
|
||||
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
||||
# active by default on some models
|
||||
config.use_cache = True
|
||||
if "token_type_ids" in inputs:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
@@ -1574,6 +1583,7 @@ class GenerationTesterMixin:
|
||||
model.eval()
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
model.generation_config.use_cache = True
|
||||
|
||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs)
|
||||
@@ -1631,7 +1641,6 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
@@ -1640,6 +1649,7 @@ class GenerationTesterMixin:
|
||||
"num_beams": num_beams,
|
||||
"num_return_sequences": num_beams,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
# Sets seed before calling `generate` for the case with do_sample=True
|
||||
@@ -1701,7 +1711,6 @@ class GenerationTesterMixin:
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
batch_size, seq_length = input_ids.shape
|
||||
max_new_tokens = 20
|
||||
@@ -1712,6 +1721,7 @@ class GenerationTesterMixin:
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
@@ -1740,7 +1750,6 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="This model does not support the quantized cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1750,6 +1759,7 @@ class GenerationTesterMixin:
|
||||
# careful with group size, should be divisor of model's hidden size
|
||||
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@@ -1890,22 +1900,24 @@ class GenerationTesterMixin:
|
||||
|
||||
# Past Key Value States -- a few notes here:
|
||||
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
||||
# complete
|
||||
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
||||
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match the
|
||||
# standard cache format (e.g. gptbigcode)
|
||||
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
if use_cache and has_standard_cache:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
num_sequences_in_output,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
config=config,
|
||||
)
|
||||
if has_standard_cache:
|
||||
if use_cache:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
num_sequences_in_output,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
config=config,
|
||||
)
|
||||
elif use_cache is False:
|
||||
self.assertTrue(output.past_key_values is None)
|
||||
|
||||
def _check_scores(self, batch_size, scores, length, config):
|
||||
expected_shape = (batch_size, config.vocab_size)
|
||||
|
||||
Reference in New Issue
Block a user