Skip torchscript tests if a cache object is in model's outputs (#35596)
* fix 1
* fix 1
* comment

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -122,7 +122,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||||
from transformers.cache_utils import DynamicCache
|
from transformers.cache_utils import Cache, DynamicCache
|
||||||
from transformers.modeling_utils import load_state_dict, no_init_weights
|
from transformers.modeling_utils import load_state_dict, no_init_weights
|
||||||
from transformers.pytorch_utils import id_tensor_storage
|
from transformers.pytorch_utils import id_tensor_storage
|
||||||
|
|
||||||
@@ -1109,7 +1109,14 @@ class ModelTesterMixin:
|
|||||||
attention_mask = inputs["attention_mask"]
|
attention_mask = inputs["attention_mask"]
|
||||||
decoder_input_ids = inputs["decoder_input_ids"]
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
# `torchscript` doesn't work with outputs containing a `Cache` object. However, #35235 makes
|
||||||
|
# several models use `Cache` by default instead of the legacy cache (tuple), and
|
||||||
|
# their `torchscript` tests are failing. We won't support them anyway, but we still want to keep
|
||||||
|
# the tests for encoder models like `BERT`. So we skip the checks if the model's output contains
|
||||||
|
# a `Cache` object.
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
)
|
)
|
||||||
@@ -1117,14 +1124,18 @@ class ModelTesterMixin:
|
|||||||
input_ids = inputs["input_ids"]
|
input_ids = inputs["input_ids"]
|
||||||
bbox = inputs["bbox"]
|
bbox = inputs["bbox"]
|
||||||
image = inputs["image"].tensor
|
image = inputs["image"].tensor
|
||||||
model(input_ids, bbox, image)
|
outputs = model(input_ids, bbox, image)
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (input_ids, bbox, image), check_trace=False
|
model, (input_ids, bbox, image), check_trace=False
|
||||||
) # when traced model is checked, an error is produced due to name mangling
|
) # when traced model is checked, an error is produced due to name mangling
|
||||||
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
|
elif "bbox" in inputs: # Bros requires additional inputs (bbox)
|
||||||
input_ids = inputs["input_ids"]
|
input_ids = inputs["input_ids"]
|
||||||
bbox = inputs["bbox"]
|
bbox = inputs["bbox"]
|
||||||
model(input_ids, bbox)
|
outputs = model(input_ids, bbox)
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (input_ids, bbox), check_trace=False
|
model, (input_ids, bbox), check_trace=False
|
||||||
) # when traced model is checked, an error is produced due to name mangling
|
) # when traced model is checked, an error is produced due to name mangling
|
||||||
@@ -1134,7 +1145,9 @@ class ModelTesterMixin:
|
|||||||
pixel_values = inputs["pixel_values"]
|
pixel_values = inputs["pixel_values"]
|
||||||
prompt_pixel_values = inputs["prompt_pixel_values"]
|
prompt_pixel_values = inputs["prompt_pixel_values"]
|
||||||
prompt_masks = inputs["prompt_masks"]
|
prompt_masks = inputs["prompt_masks"]
|
||||||
model(pixel_values, prompt_pixel_values, prompt_masks)
|
outputs = model(pixel_values, prompt_pixel_values, prompt_masks)
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
|
model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
|
||||||
) # when traced model is checked, an error is produced due to name mangling
|
) # when traced model is checked, an error is produced due to name mangling
|
||||||
@@ -1149,11 +1162,15 @@ class ModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
||||||
|
|
||||||
model(main_input, attention_mask=inputs["attention_mask"])
|
outputs = model(main_input, attention_mask=inputs["attention_mask"])
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
||||||
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
|
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
|
||||||
else:
|
else:
|
||||||
model(main_input)
|
outputs = model(main_input)
|
||||||
|
if any(isinstance(x, Cache) for x in outputs):
|
||||||
|
continue
|
||||||
traced_model = torch.jit.trace(model, (main_input,))
|
traced_model = torch.jit.trace(model, (main_input,))
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.fail("Couldn't trace module.")
|
self.fail("Couldn't trace module.")
|
||||||
|
|||||||
Reference in New Issue
Block a user