[FX] Symbolic trace for Bloom (#18356)

* Bloom model can now be traced

* Bloom traced model can be torch scripted and serialized

* Bloom can be traced with variable keyword arguments

* Enable XLNet support

* Disable XLNet for now
This commit is contained in:
Michael Benayoun
2022-07-29 16:12:27 +02:00
committed by GitHub
parent 1763770bd9
commit 4e2f4a92dd
3 changed files with 19 additions and 21 deletions

View File

@@ -320,7 +320,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
)
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
fx_compatible = False
fx_compatible = True
test_missing_keys = False
test_pruning = False
test_torchscript = True # torch.autograd functions seems to be not supported