Skip torchscript tests if a cache object is in model's outputs (#35596)
* fix 1 * fix 1 * comment --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -122,7 +122,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_utils import load_state_dict, no_init_weights
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
@@ -1109,7 +1109,14 @@ class ModelTesterMixin:
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
# `torchscript` doesn't work with outputs containing `Cache` object. However, #35235 makes
|
||||
# several models to use `Cache` by default instead of the legacy cache (tuple), and
|
||||
# their `torchscript` tests are failing. We won't support them anyway, but we still want to keep
|
||||
# the tests for encoder models like `BERT`. So we skip the checks if the model's output contains
|
||||
# a `Cache` object.
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
traced_model = torch.jit.trace(
|
||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
@@ -1117,14 +1124,18 @@ class ModelTesterMixin:
|
||||
input_ids = inputs["input_ids"]
|
||||
bbox = inputs["bbox"]
|
||||
image = inputs["image"].tensor
|
||||
model(input_ids, bbox, image)
|
||||
outputs = model(input_ids, bbox, image)
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, bbox, image), check_trace=False
|
||||
) # when traced model is checked, an error is produced due to name mangling
|
||||
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
|
||||
input_ids = inputs["input_ids"]
|
||||
bbox = inputs["bbox"]
|
||||
model(input_ids, bbox)
|
||||
outputs = model(input_ids, bbox)
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, bbox), check_trace=False
|
||||
) # when traced model is checked, an error is produced due to name mangling
|
||||
@@ -1134,7 +1145,9 @@ class ModelTesterMixin:
|
||||
pixel_values = inputs["pixel_values"]
|
||||
prompt_pixel_values = inputs["prompt_pixel_values"]
|
||||
prompt_masks = inputs["prompt_masks"]
|
||||
model(pixel_values, prompt_pixel_values, prompt_masks)
|
||||
outputs = model(pixel_values, prompt_pixel_values, prompt_masks)
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
traced_model = torch.jit.trace(
|
||||
model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
|
||||
) # when traced model is checked, an error is produced due to name mangling
|
||||
@@ -1149,11 +1162,15 @@ class ModelTesterMixin:
|
||||
else:
|
||||
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
||||
|
||||
model(main_input, attention_mask=inputs["attention_mask"])
|
||||
outputs = model(main_input, attention_mask=inputs["attention_mask"])
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
||||
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
|
||||
else:
|
||||
model(main_input)
|
||||
outputs = model(main_input)
|
||||
if any(isinstance(x, Cache) for x in outputs):
|
||||
continue
|
||||
traced_model = torch.jit.trace(model, (main_input,))
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
Reference in New Issue
Block a user