FX tracing improvement (#14321)
* Change the way tracing happens, enabling dynamic axes out of the box * Update the tests and modeling xlnet * Add the non recoding of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors). * Comments and making tracing work for gpt-j and xlnet * Refactore things related to num_choices (and batch_size, sequence_length) * Update fx to work on PyTorch 1.10 * Postpone autowrap_function feature usage for later * Add copyrights * Remove unnecessary file * Fix issue with add_new_model_like * Apply suggestions
This commit is contained in:
@@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -116,8 +116,7 @@ class ModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
fx_ready_model_classes = ()
|
||||
fx_dynamic_ready_model_classes = ()
|
||||
fx_compatible = False
|
||||
test_torchscript = True
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
@@ -666,19 +665,14 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
||||
|
||||
def test_torch_fx_dynamic_axes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
|
||||
if not is_torch_fx_available():
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
|
||||
for model_class in model_classes:
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -687,8 +681,6 @@ class ModelTesterMixin:
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
input_ids = inputs["input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
@@ -697,17 +689,7 @@ class ModelTesterMixin:
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
encoder_sequence_length = input_ids.shape[1]
|
||||
decoder_sequence_length = decoder_attention_mask.shape[1]
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size if not dynamic_axes else -1,
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
@@ -729,23 +711,12 @@ class ModelTesterMixin:
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
rank = len(input_ids.shape)
|
||||
if rank == 2:
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
num_choices = -1
|
||||
elif rank == 3:
|
||||
batch_size, num_choices, sequence_length = input_ids.shape
|
||||
else:
|
||||
if rank not in [2, 3]:
|
||||
raise NotImplementedError(
|
||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size if not dynamic_axes else -1,
|
||||
sequence_length=sequence_length if not dynamic_axes else -1,
|
||||
num_choices=num_choices,
|
||||
)
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
|
||||
@@ -209,8 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else None
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
|
||||
@@ -369,10 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
|
||||
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
test_missing_keys = False
|
||||
test_model_parallel = True
|
||||
|
||||
|
||||
@@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_model_parallel = False
|
||||
|
||||
@@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_model_parallel = False
|
||||
|
||||
@@ -283,9 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
|
||||
fx_compatible = True
|
||||
# test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
|
||||
@@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_dynamic_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTester(self)
|
||||
|
||||
@@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
fx_compatible = True
|
||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
|
||||
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
all_generative_model_classes = (
|
||||
(XLNetLMHeadModel,) if is_torch_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
|
||||
test_pruning = False
|
||||
|
||||
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
||||
|
||||
Reference in New Issue
Block a user