From 86d5fb0b360e68de46d40265e7c707fe68c8015b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 14 May 2021 20:57:30 +0200 Subject: [PATCH] Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (#11475) Symbolic tracing feature for BERT, ELECTRA and T5 Co-authored-by: Michael Benayoun Co-authored-by: Stas Bekman Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/file_utils.py | 19 ++ src/transformers/modeling_fx_utils.py | 253 ++++++++++++++++++++++ src/transformers/models/t5/modeling_t5.py | 12 +- tests/test_modeling_bert.py | 1 + tests/test_modeling_common.py | 88 +++++++- tests/test_modeling_electra.py | 1 + tests/test_modeling_t5.py | 1 + 7 files changed, 371 insertions(+), 4 deletions(-) create mode 100644 src/transformers/modeling_fx_utils.py diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 2559ce1d7b..8b559a9e71 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -265,6 +265,15 @@ def is_torch_cuda_available(): return False +_torch_fx_available = False +if _torch_available: + _torch_fx_available = version.parse(_torch_version) >= version.parse("1.8") + + +def is_torch_fx_available(): + return _torch_fx_available + + def is_tf_available(): return _tf_available @@ -1597,11 +1606,21 @@ def tf_required(func): return wrapper +def is_torch_fx_proxy(x): + if is_torch_fx_available(): + import torch.fx + + return isinstance(x, torch.fx.Proxy) + return False + + def is_tensor(x): """ Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or :obj:`np.ndarray`. """ + if is_torch_fx_proxy(x): + return True if is_torch_available(): import torch diff --git a/src/transformers/modeling_fx_utils.py b/src/transformers/modeling_fx_utils.py new file mode 100644 index 0000000000..1bad3e4ec7 --- /dev/null +++ b/src/transformers/modeling_fx_utils.py @@ -0,0 +1,253 @@ +import dis +import inspect +from typing import List, Optional, Union + +import torch +from torch.fx import GraphModule, Node, Proxy, Tracer + +from . import PreTrainedModel + + +class HFProxy(Proxy): + """ + Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing + the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the + existing ones. + """ + + def __init__(self, node: Node, tracer: Optional[Tracer] = None): + super().__init__(node, tracer=tracer) + if hasattr(self, "tracer") and self.tracer is not None: + self.device = self.tracer.root.device + self.dtype = next(self.tracer.root.parameters()).dtype + + def dim(self): + return len(self.tracer.encoder_shape) + + def _shape(self, calling_frame): + module = calling_frame.f_locals.get("self", None) + is_decoder = hasattr(module, "is_decoder") and module.is_decoder + return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape) + + def size(self, dim=None): + frame = inspect.currentframe() + calling_frame = frame.f_back + + # self.size can be called through the shape property, in which case we need to get the outer + # frame, containing the meaningful information. + if calling_frame.f_code.co_name == "shape": + calling_frame = calling_frame.f_back + + instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti])) + code_context = inspect.getframeinfo(calling_frame).code_context[0].strip() + + shape = self._shape(calling_frame) + + if calling_frame.f_code.co_name == "transpose_for_scores": + # Provides the proper "x.size()" for: + # new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + shape = shape + [-1] + elif "context_layer" in calling_frame.f_locals: + # Provides the proper "context_layer.size()" for: + # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + shape = shape + [-1, -1] + elif calling_frame.f_locals.get("do_cross_attention", False): + # Provides the proper shape for: + # query_length = present_key_value_state[0].shape[2] + # (modeling_t5.py) + shape = list(self.tracer.encoder_shape) + shape = shape[:1] + [-1] + shape[1:2] + elif "key_length" in code_context or "encoder_seq_length" in code_context: + shape = list(self.tracer.encoder_shape) + elif "lm_logits.size(-1)" in code_context: + shape = [self.tracer.root.config.vocab_size] + elif "start_positions" in code_context or "end_positions" in code_context: + # For question answering tasks. + shape = [1] + elif "num_choices" in code_context: + if self.tracer.num_choices <= 0: + raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.") + shape = shape[:1] + [self.tracer.num_choices] + shape[1:] + else: + # Default case: + # - If self.size is called for an unpacking, retrieves the corresponding unpacking + # instruction, and returns the shape padded as much as necessary to match the expected + # number of items. + # - If self.size is called outside of an unpacking context, simply return the shape. + is_unpack = False + + for inst in instructions: + if inst.opname == "UNPACK_SEQUENCE": + is_unpack = True + break + + if is_unpack and inst.argval >= 3: + shape += [self.tracer.root.config.hidden_size] + dummy_values = [1] * (inst.argval - 3) + shape += dummy_values + + if dim is not None: + return shape[dim] + + return tuple(shape) + + @property + def shape(self): + return self.size() + + def __bool__(self) -> bool: + frame = inspect.currentframe() + calling_frame = frame.f_back + code_context = inspect.getframeinfo(calling_frame).code_context[0].strip() + if calling_frame.f_code.co_name == "apply_chunking_to_forward": + # Returning True to every assertion in "apply_chuncking_to_forward" + return True + elif "assert" in code_context: + # Returning True to any assertion. + return True + elif calling_frame.f_code.co_name == "get_extended_attention_mask": + # Corresponding to: + # if causal_mask.shape[1] < attention_mask.shape[1]: + return calling_frame.f_back.f_locals["past_key_values"][0] is not None + raise NotImplementedError("__bool__ was called for CustomProxy, but this case is not covered yet.") + + def __setitem__(self, key, value): + pass + + def __contains__(self, key): + return False + + +class HFTracer(Tracer): + """ + Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it + uses the HFProxy instead of the regular PyTorch torch.fx.Proxy. + """ + + def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): + super().__init__() + encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length + decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1 + self.encoder_shape = [batch_size, encoder_sequence_length] + self.decoder_shape = ( + [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape) + ) + self.num_choices = num_choices + if self.num_choices > 0: + self.encoder_shape[0] *= self.num_choices + + self.prev_module = None + + def proxy(self, node: Node): + return HFProxy(node, self) + + def _insert_module_as_submodule(self, mod): + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + # First, retrieve the parent module. + if self.prev_module is None: + return None + parent_path = self.prev_module.rsplit(".", 1)[0] + parent_mod = None + for path, module in self.root.named_modules(): + if path == parent_path: + parent_mod = module + break + if parent_mod is None: + return None + + # If retrieving the parent module was possible, set the module not declared as a submodule + # as a parent module attribute. + path = None + for var_name, var_val in inspect.currentframe().f_back.f_locals.items(): + if mod is var_val: + setattr(parent_mod, var_name, mod) + path = f"{parent_path}.{var_name}" + break + + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if + ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function + will return the string "foo.bar". + + Args: + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if hasattr(self, "submodule_paths") and self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + path = self._insert_module_as_submodule(mod) + if path is None: + raise NameError("module is not installed as a submodule") + self.prev_module = path + return path + + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + self.prev_module = n + return n + path = self._insert_module_as_submodule(mod) + if path is None: + raise NameError("module is not installed as a submodule") + self.prev_module = path + return path + + +def symbolic_trace( + model: PreTrainedModel, + input_names: Optional[List[str]] = None, + batch_size: int = 1, + sequence_length: Union[int, List[int]] = [128, 128], + num_choices: int = -1, +) -> GraphModule: + + """ + Performs symbolic tracing on the model. + + Args: + model (:obj:`PretrainedModel`): + The model to trace. + input_names (:obj:`List[str]`, `optional`): + The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead. + batch_size (:obj:`int`, `optional`, defaults to 1): + The batch size of the traced model inputs. + sequence_length (:obj:`int` or :obj:`List[int]]`): + The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence + lengths between the encoder and the decoder inputs, this must be :obj:`[encoder_sequence_length, + decoder_sequence_length]`. + num_choices (:obj:`int`, `optional`, defaults to -1): + The number of possible choices for a multiple choice task. + + Returns: + :obj:`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. + + Example:: + + from transformers.modeling_fx_utils import symbolic_trace + traced_model = symbolic_trace( + model, + input_names=["input_ids", "attention_mask", "token_type_ids"], + batch_size=1, + sequence_length=128, + ) + """ + if input_names is None: + input_names = model.dummy_inputs.keys() + + sig = inspect.signature(model.forward) + # TODO: how to handle the case of the "return_dict" parameter. + concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + + tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) + traced_graph = tracer.trace(model, concrete_args=concrete_args) + traced = torch.fx.GraphModule(model, traced_graph) + + return traced diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 4d570fec16..02b79d8901 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -32,6 +32,7 @@ from ...file_utils import ( DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torch_fx_proxy, replace_return_docstrings, ) from ...modeling_outputs import ( @@ -776,9 +777,14 @@ class T5PreTrainedModel(PreTrainedModel): ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" # shift inputs to the right - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = decoder_start_token_id + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." # replace possible -100 values in labels by `pad_token_id` diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index acd921ce8a..c87c97a543 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3ff21b1d5a..837e267bdd 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,7 +25,7 @@ from typing import List, Tuple from huggingface_hub import HfApi from requests.exceptions import HTTPError from transformers import is_torch_available, logging -from transformers.file_utils import WEIGHTS_NAME +from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available from transformers.models.auto import get_values from transformers.testing_utils import ( ENDPOINT_STAGING, @@ -64,6 +64,9 @@ if is_torch_available(): T5ForConditionalGeneration, ) +if is_torch_fx_available(): + from transformers.modeling_fx_utils import symbolic_trace + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) @@ -82,6 +85,7 @@ class ModelTesterMixin: model_tester = None all_model_classes = () all_generative_model_classes = () + fx_ready_model_classes = () test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -565,6 +569,88 @@ class ModelTesterMixin: self.assertTrue(models_equal) + def test_torch_fx(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + self._create_and_check_torch_fx_tracing(config, inputs_dict) + + def test_torch_fx_output_loss(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) + + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available(): + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.fx_ready_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + input_ids = inputs["input_ids"] + decoder_attention_mask = inputs["decoder_attention_mask"] + labels = inputs.get("labels", None) + input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] + if labels is not None: + input_names.append("labels") + prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + + model_output = model(**prepared_inputs) + + batch_size = input_ids.shape[0] + encoder_sequence_length = input_ids.shape[1] + decoder_sequence_length = decoder_attention_mask.shape[1] + + traced_model = symbolic_trace( + model, + input_names, + batch_size=batch_size, + sequence_length=[encoder_sequence_length, decoder_sequence_length], + ) + + traced_output = traced_model(**prepared_inputs) + + else: + input_ids = inputs["input_ids"] + labels = inputs.get("labels", None) + input_names = ["input_ids", "attention_mask", "token_type_ids"] + if labels is not None: + input_names.append("labels") + prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + + model_output = model(**prepared_inputs) + + batch_size = input_ids.shape[0] + + if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + sequence_length = input_ids.shape[2] + num_choices = input_ids.shape[1] + else: + sequence_length = input_ids.shape[1] + num_choices = -1 + + traced_model = symbolic_trace( + model, + input_names, + batch_size=batch_size, + sequence_length=sequence_length, + num_choices=num_choices, + ) + traced_output = traced_model(**prepared_inputs) + + except RuntimeError: + self.fail("Couldn't trace module.") + + num_outputs = len(model_output) + outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs)) + self.assertTrue(outputs_are_close) + def test_headmasking(self): if not self.test_head_masking: return diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 366d8f0f90..8fcbb445a1 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 31b712b075..55b9c05682 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -488,6 +488,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True