[tests] reset logs in torch.compile test (#37894)
This commit is contained in:
@@ -2175,23 +2175,27 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
compiled_outputs = []
|
compiled_outputs = []
|
||||||
# Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails.
|
# Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails.
|
||||||
torch._logging.set_logs(recompiles_verbose=True)
|
# Try/Finally is used to ensure that the log options are reset even if an error is raised.
|
||||||
logger = logging.get_logger("torch._dynamo.guards")
|
try:
|
||||||
with CaptureLogger(logger) as cl:
|
torch._logging.set_logs(recompiles_verbose=True)
|
||||||
for model_inputs in model_input_sets:
|
logger = logging.get_logger("torch._dynamo.guards")
|
||||||
# with torch.compiler.set_stance("fail_on_recompile"):
|
with CaptureLogger(logger) as cl:
|
||||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
for model_inputs in model_input_sets:
|
||||||
compiled_outputs.append(gen_out)
|
# with torch.compiler.set_stance("fail_on_recompile"):
|
||||||
# sanity checks
|
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||||
decoder_cache = (
|
compiled_outputs.append(gen_out)
|
||||||
gen_out.past_key_values.self_attention_cache
|
# sanity checks
|
||||||
if config.is_encoder_decoder
|
decoder_cache = (
|
||||||
else gen_out.past_key_values
|
gen_out.past_key_values.self_attention_cache
|
||||||
)
|
if config.is_encoder_decoder
|
||||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
else gen_out.past_key_values
|
||||||
self.assertTrue(decoder_cache.is_compileable)
|
)
|
||||||
# our auto compile should have been called
|
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||||
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call"))
|
self.assertTrue(decoder_cache.is_compileable)
|
||||||
|
# our auto compile should have been called
|
||||||
|
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call"))
|
||||||
|
finally:
|
||||||
|
torch._logging.set_logs()
|
||||||
|
|
||||||
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out):
|
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
Reference in New Issue
Block a user