Generate: Deprecate returning legacy cache by default; Handle use_cache=False (#32863)

This commit is contained in:
Joao Gante
2024-08-22 20:01:52 +01:00
committed by GitHub
parent 09e6579d2d
commit a26de15139
4 changed files with 311 additions and 256 deletions

View File

@@ -194,6 +194,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@@ -207,6 +208,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
)
@@ -224,6 +226,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
@@ -239,6 +242,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
)
@@ -256,6 +260,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@@ -268,6 +273,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@@ -286,6 +292,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
@@ -299,6 +306,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@@ -317,6 +325,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@@ -329,6 +338,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@@ -348,6 +358,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@@ -361,6 +372,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
constraints=constraints,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@@ -378,6 +390,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
contrastive_search_kwargs = {
"penalty_alpha": 0.6,
@@ -396,6 +409,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
@@ -419,7 +433,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
@@ -430,6 +443,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -454,7 +468,6 @@ class GenerationTesterMixin:
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@@ -466,6 +479,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@@ -495,7 +509,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
@@ -507,6 +520,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -545,9 +559,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
@@ -560,6 +571,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
@@ -589,7 +601,6 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
@@ -602,6 +613,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@@ -676,9 +688,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@@ -692,6 +701,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -764,7 +774,6 @@ class GenerationTesterMixin:
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
@@ -778,6 +787,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
@@ -857,9 +867,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
# Sample constraints
@@ -882,6 +889,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -913,13 +921,12 @@ class GenerationTesterMixin:
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
# test old generation output for backwards compatibility
model = model_class(config).to(torch_device).eval()
output_generate = self._contrastive_generate(
model=model, input_ids=input_ids, attention_mask=attention_mask
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
@@ -940,7 +947,6 @@ class GenerationTesterMixin:
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@@ -953,6 +959,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@@ -978,7 +985,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
@@ -991,6 +997,7 @@ class GenerationTesterMixin:
low_memory=True,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)
high_output = model.generate(
@@ -1000,6 +1007,7 @@ class GenerationTesterMixin:
low_memory=False,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@@ -1031,10 +1039,17 @@ class GenerationTesterMixin:
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
low_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
)
high_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
input_ids,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=False,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@@ -1079,7 +1094,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@@ -1098,6 +1112,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1150,7 +1165,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@@ -1169,6 +1183,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1196,12 +1211,6 @@ class GenerationTesterMixin:
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, input_ids, attention_mask = self._get_input_ids_and_config()
# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
@@ -1224,11 +1233,12 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
}
generation_kwargs.update({"dola_layers": "low"})
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
@@ -1261,7 +1271,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@@ -1284,6 +1293,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1566,7 +1576,6 @@ class GenerationTesterMixin:
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
config.use_cache = True
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
@@ -1574,6 +1583,7 @@ class GenerationTesterMixin:
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.use_cache = True
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
@@ -1631,7 +1641,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the new cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
@@ -1640,6 +1649,7 @@ class GenerationTesterMixin:
"num_beams": num_beams,
"num_return_sequences": num_beams,
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
# Sets seed before calling `generate` for the case with do_sample=True
@@ -1701,7 +1711,6 @@ class GenerationTesterMixin:
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
@@ -1712,6 +1721,7 @@ class GenerationTesterMixin:
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
max_cache_len = seq_length + max_new_tokens
@@ -1740,7 +1750,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@@ -1750,6 +1759,7 @@ class GenerationTesterMixin:
# careful with group size, should be divisor of model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1890,22 +1900,24 @@ class GenerationTesterMixin:
# Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match the
# standard cache format (e.g. gptbigcode)
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
if use_cache and has_standard_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
if has_standard_cache:
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
elif use_cache is False:
self.assertTrue(output.past_key_values is None)
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)