From 6f127d3f81c8cf434aa809a4d8ec76b3a6372060 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 10 Jan 2025 10:46:03 +0100 Subject: [PATCH] Skip `torchscript` tests if a cache object is in model's outputs (#35596) * fix 1 * fix 1 * comment --------- Co-authored-by: ydshieh --- tests/test_modeling_common.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c29a15efd3..bb2b17f8e5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -122,7 +122,7 @@ if is_torch_available(): from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding - from transformers.cache_utils import DynamicCache + from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -1109,7 +1109,14 @@ class ModelTesterMixin: attention_mask = inputs["attention_mask"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] - model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) + outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) + # `torchscript` doesn't work with outputs containing a `Cache` object. However, #35235 makes + # several models use `Cache` by default instead of the legacy cache (tuple), and + # their `torchscript` tests are failing. We won't support them anyway, but we still want to keep + # the tests for encoder models like `BERT`. So we skip the checks if the model's output contains + # a `Cache` object. 
+ if any(isinstance(x, Cache) for x in outputs): + continue traced_model = torch.jit.trace( model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) ) @@ -1117,14 +1124,18 @@ class ModelTesterMixin: input_ids = inputs["input_ids"] bbox = inputs["bbox"] image = inputs["image"].tensor - model(input_ids, bbox, image) + outputs = model(input_ids, bbox, image) + if any(isinstance(x, Cache) for x in outputs): + continue traced_model = torch.jit.trace( model, (input_ids, bbox, image), check_trace=False ) # when traced model is checked, an error is produced due to name mangling elif "bbox" in inputs: # Bros requires additional inputs (bbox) input_ids = inputs["input_ids"] bbox = inputs["bbox"] - model(input_ids, bbox) + outputs = model(input_ids, bbox) + if any(isinstance(x, Cache) for x in outputs): + continue traced_model = torch.jit.trace( model, (input_ids, bbox), check_trace=False ) # when traced model is checked, an error is produced due to name mangling @@ -1134,7 +1145,9 @@ class ModelTesterMixin: pixel_values = inputs["pixel_values"] prompt_pixel_values = inputs["prompt_pixel_values"] prompt_masks = inputs["prompt_masks"] - model(pixel_values, prompt_pixel_values, prompt_masks) + outputs = model(pixel_values, prompt_pixel_values, prompt_masks) + if any(isinstance(x, Cache) for x in outputs): + continue traced_model = torch.jit.trace( model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False ) # when traced model is checked, an error is produced due to name mangling @@ -1149,11 +1162,15 @@ class ModelTesterMixin: else: self.skipTest(reason="testing SDPA without attention_mask is not supported") - model(main_input, attention_mask=inputs["attention_mask"]) + outputs = model(main_input, attention_mask=inputs["attention_mask"]) + if any(isinstance(x, Cache) for x in outputs): + continue # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. 
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) else: - model(main_input) + outputs = model(main_input) + if any(isinstance(x, Cache) for x in outputs): + continue traced_model = torch.jit.trace(model, (main_input,)) except RuntimeError: self.fail("Couldn't trace module.")