ViT and Swin symbolic tracing with torch.fx (#17182)
* Support tracing for ViT * Swin support * Fix copies * Fix type annotation issue * Removed unused import
This commit is contained in:
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
@@ -738,8 +738,7 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
input_ids = inputs["input_ids"]
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
@@ -756,12 +755,6 @@ class ModelTesterMixin:
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
rank = len(input_ids.shape)
|
||||
if rank not in [2, 3]:
|
||||
raise NotImplementedError(
|
||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user