Llama: make slow tests green 🟢 (#33138)
This commit is contained in:
@@ -1803,6 +1803,8 @@ class GenerationTesterMixin:
|
||||
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
||||
|
||||
input_ids = inputs_dict["input_ids"].to(torch_device)
|
||||
# creates two sets of *different* inputs with the same shape
|
||||
half_batch_size = input_ids.shape[0] // 2
|
||||
@@ -1815,22 +1817,14 @@ class GenerationTesterMixin:
|
||||
}
|
||||
|
||||
for model_inputs in input_ids_sets:
|
||||
# dynamic cache
|
||||
# eager dynamic cache
|
||||
output_dynamic = model.generate(model_inputs, **generation_kwargs)
|
||||
|
||||
# eager static cache
|
||||
torch.compiler.reset()
|
||||
model.generation_config.cache_implementation = "static"
|
||||
output_static = model.generate(model_inputs, **generation_kwargs)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
|
||||
|
||||
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
|
||||
# initialize the cache in full compiled mode)
|
||||
model._cache = None
|
||||
# end-to-end compiled dynamic cache
|
||||
torch.compiler.reset()
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.update(**generation_kwargs)
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user