Skip torchscript tests if a cache object is in model's outputs (#35596)

* fix 1

* fix 1

* comment

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-10 10:46:03 +01:00
committed by GitHub
parent 6b73ee8905
commit 6f127d3f81

View File

@@ -122,7 +122,7 @@ if is_torch_available():
from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage
@@ -1109,7 +1109,14 @@ class ModelTesterMixin:
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
# `torchscript` doesn't work with outputs containing `Cache` object. However, #35235 makes
# several models to use `Cache` by default instead of the legacy cache (tuple), and
# their `torchscript` tests are failing. We won't support them anyway, but we still want to keep
# the tests for encoder models like `BERT`. So we skip the checks if the model's output contains
# a `Cache` object.
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
@@ -1117,14 +1124,18 @@ class ModelTesterMixin:
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
outputs = model(input_ids, bbox, image)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
model(input_ids, bbox)
outputs = model(input_ids, bbox)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(
model, (input_ids, bbox), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
@@ -1134,7 +1145,9 @@ class ModelTesterMixin:
pixel_values = inputs["pixel_values"]
prompt_pixel_values = inputs["prompt_pixel_values"]
prompt_masks = inputs["prompt_masks"]
model(pixel_values, prompt_pixel_values, prompt_masks)
outputs = model(pixel_values, prompt_pixel_values, prompt_masks)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(
model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
@@ -1149,11 +1162,15 @@ class ModelTesterMixin:
else:
self.skipTest(reason="testing SDPA without attention_mask is not supported")
model(main_input, attention_mask=inputs["attention_mask"])
outputs = model(main_input, attention_mask=inputs["attention_mask"])
if any(isinstance(x, Cache) for x in outputs):
continue
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input)
outputs = model(main_input)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")