Contrastive Search peak memory reduction (#24120)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Benjamin Badger
2023-07-20 13:46:53 -04:00
committed by GitHub
parent aa1b09c5d1
commit caf5e369fc
3 changed files with 147 additions and 31 deletions

View File

@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_contrastive_generate_low_memory(self):
    """Check that choosing `low_memory` does not change the contrastive search output.

    For each class in `self.all_generative_model_classes`, `generate` is called twice on the same
    inputs with `top_k=4` and `penalty_alpha=0.6`: once with `low_memory=True` and once with
    `low_memory=False`. The two generated sequences must be exactly equal.

    Classes whose name matches one of the unsupported model names, and configs without a
    `use_cache` attribute, are skipped individually; the remaining classes are still checked.
    """
    # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
    unsupported_model_names = ("fsmt", "reformer", "gptbigcode", "speech2text")

    for model_class in self.all_generative_model_classes:
        class_name = model_class.__name__.lower()
        if any(model_name in class_name for model_name in unsupported_model_names):
            # Skip only this class instead of ending the whole test.
            continue

        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

        # NOTE: contrastive search only works with cache on at the moment.
        if not hasattr(config, "use_cache"):
            continue

        config.use_cache = True
        config.is_decoder = True

        # test output equality of low versus high memory
        model = model_class(config).to(torch_device).eval()

        # Everything except `low_memory` is shared, so any output difference is attributable to it.
        generate_kwargs = {
            "top_k": 4,
            "penalty_alpha": 0.6,
            "max_length": max_length,
            "attention_mask": attention_mask,
        }
        low_output = model.generate(input_ids, low_memory=True, **generate_kwargs)
        high_output = model.generate(input_ids, low_memory=False, **generate_kwargs)

        self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.