Llama: make slow tests green 🟢 (#33138)

This commit is contained in:
Joao Gante
2024-08-27 14:44:42 +01:00
committed by GitHub
parent 9956c2bc98
commit c6b23fda65
31 changed files with 39 additions and 180 deletions

View File

@@ -1803,6 +1803,8 @@ class GenerationTesterMixin:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
input_ids = inputs_dict["input_ids"].to(torch_device)
# creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2
@@ -1815,22 +1817,14 @@ class GenerationTesterMixin:
}
for model_inputs in input_ids_sets:
# dynamic cache
# eager dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs)
# eager static cache
torch.compiler.reset()
model.generation_config.cache_implementation = "static"
output_static = model.generate(model_inputs, **generation_kwargs)
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
# end-to-end compiled dynamic cache
torch.compiler.reset()
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())