[tests] reset logs in torch.compile test (#37894)

This commit is contained in:
Joao Gante
2025-04-30 16:04:28 +01:00
committed by GitHub
parent 1b222903c3
commit 8e8025b384

View File

@@ -2175,23 +2175,27 @@ class GenerationTesterMixin:
compiled_outputs = [] compiled_outputs = []
# Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails. # Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails.
torch._logging.set_logs(recompiles_verbose=True) # Try/Finally is used to ensure that the log options are reset even if an error is raised.
logger = logging.get_logger("torch._dynamo.guards") try:
with CaptureLogger(logger) as cl: torch._logging.set_logs(recompiles_verbose=True)
for model_inputs in model_input_sets: logger = logging.get_logger("torch._dynamo.guards")
# with torch.compiler.set_stance("fail_on_recompile"): with CaptureLogger(logger) as cl:
gen_out = model.generate(**model_inputs, **generation_kwargs) for model_inputs in model_input_sets:
compiled_outputs.append(gen_out) # with torch.compiler.set_stance("fail_on_recompile"):
# sanity checks gen_out = model.generate(**model_inputs, **generation_kwargs)
decoder_cache = ( compiled_outputs.append(gen_out)
gen_out.past_key_values.self_attention_cache # sanity checks
if config.is_encoder_decoder decoder_cache = (
else gen_out.past_key_values gen_out.past_key_values.self_attention_cache
) if config.is_encoder_decoder
self.assertFalse(isinstance(decoder_cache, DynamicCache)) else gen_out.past_key_values
self.assertTrue(decoder_cache.is_compileable) )
# our auto compile should have been called self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call")) self.assertTrue(decoder_cache.is_compileable)
# our auto compile should have been called
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call"))
finally:
torch._logging.set_logs()
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out): if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out):
raise RuntimeError( raise RuntimeError(