Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (#11475)
Symbolic tracing feature for BERT, ELECTRA and T5 Co-authored-by: Michael Benayoun <michael@huggingface.co> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
@@ -25,7 +25,7 @@ from typing import List, Tuple
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
@@ -64,6 +64,9 @@ if is_torch_available():
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.modeling_fx_utils import symbolic_trace
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -82,6 +85,7 @@ class ModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
fx_ready_model_classes = ()
|
||||
test_torchscript = True
|
||||
test_pruning = True
|
||||
test_resize_embeddings = True
|
||||
@@ -565,6 +569,88 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_torch_fx(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict)
|
||||
|
||||
def test_torch_fx_output_loss(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available():
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
|
||||
for model_class in self.fx_ready_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
input_ids = inputs["input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
encoder_sequence_length = input_ids.shape[1]
|
||||
decoder_sequence_length = decoder_attention_mask.shape[1]
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size,
|
||||
sequence_length=[encoder_sequence_length, decoder_sequence_length],
|
||||
)
|
||||
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
|
||||
else:
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
|
||||
model_output = model(**prepared_inputs)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
sequence_length = input_ids.shape[2]
|
||||
num_choices = input_ids.shape[1]
|
||||
else:
|
||||
sequence_length = input_ids.shape[1]
|
||||
num_choices = -1
|
||||
|
||||
traced_model = symbolic_trace(
|
||||
model,
|
||||
input_names,
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_choices=num_choices,
|
||||
)
|
||||
traced_output = traced_model(**prepared_inputs)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
num_outputs = len(model_output)
|
||||
outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
|
||||
self.assertTrue(outputs_are_close)
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
|
||||
@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_sequence_classification_problem_types = True
|
||||
|
||||
# special case for ForPreTraining model
|
||||
|
||||
@@ -488,6 +488,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
|
||||
Reference in New Issue
Block a user