A cleaner and more scalable implementation of symbolic tracing (#11763)

Cleaner and more scalable implementation of symbolic tracing with torch.fx, and provides support for new architectures:
- ALBERT
- DistilBERT
- MobileBERT
- MegatronBERT
- GPT2
- GPT Neo

Co-authored-by: Michael Benayoun <michael@huggingface.co>
This commit is contained in:
Michael Benayoun
2021-05-20 18:02:29 +02:00
committed by GitHub
parent 469384a777
commit f4a0d6ff86
8 changed files with 260 additions and 114 deletions

View File

@@ -399,6 +399,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
test_missing_keys = False
test_model_parallel = True