Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579)

* Symbolic trace dynamic axes support for BERT like models (albert, bert, distilbert, mobilebert, electra, megatron-bert)
* Sanity checks before tracing that make sure the model to trace is supported
* Adapted to PyTorch 1.9

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-10-05 14:19:47 +02:00
committed by GitHub
parent 46efc58024
commit d4e4efce68
11 changed files with 571 additions and 17 deletions

View File

@@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model