Contrastive Search peak memory reduction (#24120)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
|
||||
for output in (output_contrastive, output_generate):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
|
||||
):
|
||||
return
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
low_output = model.generate(
|
||||
input_ids,
|
||||
top_k=4,
|
||||
penalty_alpha=0.6,
|
||||
low_memory=True,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
input_ids,
|
||||
top_k=4,
|
||||
penalty_alpha=0.6,
|
||||
low_memory=False,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
return
|
||||
|
||||
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
|
||||
Reference in New Issue
Block a user