Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579)
* Symbolic trace dynamic axes support for BERT-like models (albert, bert, distilbert, mobilebert, electra, megatron-bert)
* Sanity checks before tracing that make sure the model to trace is supported
* Adapted to PyTorch 1.9

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
@@ -232,6 +232,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
|
||||
@@ -445,6 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
@@ -92,6 +92,7 @@ class ModelTesterMixin:
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
fx_ready_model_classes = ()
|
||||
fx_dynamic_ready_model_classes = ()
|
||||
test_torchscript = True
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
@@ -607,14 +608,19 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
def test_torch_fx_dynamic_axes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
|
||||
if not is_torch_fx_available():
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.fx_ready_model_classes:
|
||||
model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
|
||||
for model_class in model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -640,12 +646,11 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size,
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length],
|
||||
batch_size=batch_size if not dynamic_axes else -1,
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
|
||||
)
|
||||
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
input_ids = inputs["input_ids"]
|
||||
@@ -679,8 +684,8 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
batch_size=batch_size if not dynamic_axes else -1,
|
||||
sequence_length=sequence_length if not dynamic_axes else -1,
|
||||
num_choices=num_choices,
|
||||
)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
@@ -210,6 +210,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else None
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
@@ -284,6 +284,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
|
||||
# test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
@@ -270,6 +270,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
Reference in New Issue
Block a user