FX tracing improvement (#14321)

* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Add the non recoding of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors).

* Comments and making tracing work for gpt-j and xlnet

* Refactore things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone autowrap_function feature usage for later

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
This commit is contained in:
Michael Benayoun
2022-02-07 22:25:33 +01:00
committed by GitHub
parent 552f8d3091
commit 0fe17f375a
31 changed files with 349 additions and 689 deletions

View File

@@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
fx_compatible = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):