ViT and Swin symbolic tracing with torch.fx (#17182)

* Support tracing for ViT

* Swin support

* Fix copies

* Fix type annotation issue

* Removed unused import
This commit is contained in:
Michael Benayoun
2022-05-12 10:42:27 +02:00
committed by GitHub
parent 1a688709b3
commit 8c7481f35c
11 changed files with 70 additions and 35 deletions

View File

@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_compatible = True
test_pruning = False
test_resize_embeddings = False

View File

@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_compatible = True
test_pruning = False
test_resize_embeddings = False

View File

@@ -738,8 +738,7 @@ class ModelTesterMixin:
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = ["input_ids", "attention_mask", "token_type_ids"]
input_ids = inputs["input_ids"]
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
@@ -756,12 +755,6 @@ class ModelTesterMixin:
model_output = model(**filtered_inputs)
rank = len(input_ids.shape)
if rank not in [2, 3]:
raise NotImplementedError(
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)