Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (#11475)

Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Michael Benayoun
2021-05-14 20:57:30 +02:00
committed by GitHub
parent 94a2348706
commit 86d5fb0b36
7 changed files with 371 additions and 4 deletions

View File

@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model