Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579)

* Symbolic tracing support with dynamic axes for BERT-like models (albert, bert, distilbert, mobilebert, electra, megatron-bert)
* Sanity checks before tracing to make sure the model being traced is supported
* Adapted to PyTorch 1.9

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-10-05 14:19:47 +02:00
committed by GitHub
parent 46efc58024
commit d4e4efce68
11 changed files with 571 additions and 17 deletions

View File

@@ -232,6 +232,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True

View File

@@ -445,6 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model

View File

@@ -92,6 +92,7 @@ class ModelTesterMixin:
all_model_classes = ()
all_generative_model_classes = ()
fx_ready_model_classes = ()
fx_dynamic_ready_model_classes = ()
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
@@ -607,14 +608,19 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
def test_torch_fx_dynamic_axes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
if not is_torch_fx_available():
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.fx_ready_model_classes:
model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
for model_class in model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
@@ -640,12 +646,11 @@ class ModelTesterMixin:
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=[encoder_sequence_length, decoder_sequence_length],
batch_size=batch_size if not dynamic_axes else -1,
sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
)
traced_output = traced_model(**filtered_inputs)
else:
input_names = ["input_ids", "attention_mask", "token_type_ids"]
input_ids = inputs["input_ids"]
@@ -679,8 +684,8 @@ class ModelTesterMixin:
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=sequence_length,
batch_size=batch_size if not dynamic_axes else -1,
sequence_length=sequence_length if not dynamic_axes else -1,
num_choices=num_choices,
)
traced_output = traced_model(**filtered_inputs)

View File

@@ -210,6 +210,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
else None
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_pruning = True
test_torchscript = True
test_resize_embeddings = True

View File

@@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model

View File

@@ -284,6 +284,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
# test_resize_embeddings = False
test_head_masking = False

View File

@@ -270,6 +270,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model