From 8e8025b384f11e786b617304c724d0d28a308552 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 30 Apr 2025 16:04:28 +0100 Subject: [PATCH] [tests] reset logs in `torch.compile` test (#37894) --- tests/generation/test_utils.py | 38 +++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 64f728224a..46a1b90001 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2175,23 +2175,27 @@ class GenerationTesterMixin: compiled_outputs = [] # Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails. - torch._logging.set_logs(recompiles_verbose=True) - logger = logging.get_logger("torch._dynamo.guards") - with CaptureLogger(logger) as cl: - for model_inputs in model_input_sets: - # with torch.compiler.set_stance("fail_on_recompile"): - gen_out = model.generate(**model_inputs, **generation_kwargs) - compiled_outputs.append(gen_out) - # sanity checks - decoder_cache = ( - gen_out.past_key_values.self_attention_cache - if config.is_encoder_decoder - else gen_out.past_key_values - ) - self.assertFalse(isinstance(decoder_cache, DynamicCache)) - self.assertTrue(decoder_cache.is_compileable) - # our auto compile should have been called - self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call")) + # Try/Finally is used to ensure that the log options are reset even if an error is raised. + try: + torch._logging.set_logs(recompiles_verbose=True) + logger = logging.get_logger("torch._dynamo.guards") + with CaptureLogger(logger) as cl: + for model_inputs in model_input_sets: + # with torch.compiler.set_stance("fail_on_recompile"): + gen_out = model.generate(**model_inputs, **generation_kwargs) + compiled_outputs.append(gen_out) + # sanity checks + decoder_cache = ( + gen_out.past_key_values.self_attention_cache + if config.is_encoder_decoder + else gen_out.past_key_values + ) + self.assertFalse(isinstance(decoder_cache, DynamicCache)) + self.assertTrue(decoder_cache.is_compileable) + # our auto compile should have been called + self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call")) + finally: + torch._logging.set_logs() if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out): raise RuntimeError(