From f4a0d6ff867e8a82a33d7a653e7d45372a463271 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 May 2021 18:02:29 +0200 Subject: [PATCH] A cleaner and more scalable implementation of symbolic tracing (#11763) Cleaner and more scalable implementation of symbolic tracing with torch.fx, and provides support for new architectures: - ALBERT - DistilBERT - MobileBERT - MegatronBERT - GPT2 - GPT Neo Co-authored-by: Michael Benayoun --- src/transformers/modeling_fx_utils.py | 307 ++++++++++++++++++-------- tests/test_modeling_albert.py | 1 + tests/test_modeling_common.py | 60 +++-- tests/test_modeling_distilbert.py | 1 + tests/test_modeling_gpt2.py | 1 + tests/test_modeling_gpt_neo.py | 1 + tests/test_modeling_megatron_bert.py | 2 +- tests/test_modeling_mobilebert.py | 1 + 8 files changed, 260 insertions(+), 114 deletions(-) diff --git a/src/transformers/modeling_fx_utils.py b/src/transformers/modeling_fx_utils.py index e9cdf00ce8..6c43a56bfb 100644 --- a/src/transformers/modeling_fx_utils.py +++ b/src/transformers/modeling_fx_utils.py @@ -1,11 +1,31 @@ -import dis +import copy +import functools import inspect -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -from torch.fx import GraphModule, Node, Proxy, Tracer +from torch.fx import Graph, GraphModule, Node, Proxy, Tracer +from torch.fx.node import Argument -from . import PreTrainedModel +from . import ( + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + GPT2DoubleHeadsModel, + PreTrainedModel, + logging, +) +from .models.auto import get_values + + +logger = logging.get_logger(__name__) class HFProxy(Proxy): @@ -21,98 +41,10 @@ class HFProxy(Proxy): self.device = self.tracer.root.device self.dtype = next(self.tracer.root.parameters()).dtype - def dim(self): - return len(self.tracer.encoder_shape) - - def _shape(self, calling_frame): - module = calling_frame.f_locals.get("self", None) - is_decoder = hasattr(module, "is_decoder") and module.is_decoder - return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape) - - def size(self, dim=None): - frame = inspect.currentframe() - calling_frame = frame.f_back - - # self.size can be called through the shape property, in which case we need to get the outer - # frame, containing the meaningful information. - if calling_frame.f_code.co_name == "shape": - calling_frame = calling_frame.f_back - - instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti])) - code_context = inspect.getframeinfo(calling_frame).code_context[0].strip() - - shape = self._shape(calling_frame) - - if calling_frame.f_code.co_name == "transpose_for_scores": - # Provides the proper "x.size()" for: - # new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - shape = shape + [-1] - elif "context_layer" in calling_frame.f_locals: - # Provides the proper "context_layer.size()" for: - # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - shape = shape + [-1, -1] - elif calling_frame.f_locals.get("do_cross_attention", False): - # Provides the proper shape for: - # query_length = present_key_value_state[0].shape[2] - # (modeling_t5.py) - shape = list(self.tracer.encoder_shape) - shape = shape[:1] + [-1] + shape[1:2] - elif "key_length" in code_context or "encoder_seq_length" in code_context: - shape = list(self.tracer.encoder_shape) - elif "lm_logits.size(-1)" in code_context: - shape = [self.tracer.root.config.vocab_size] - elif "start_positions" in code_context or "end_positions" in code_context: - # For question answering tasks. - shape = [1] - elif "num_choices" in code_context: - if self.tracer.num_choices <= 0: - raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.") - shape = shape[:1] + [self.tracer.num_choices] + shape[1:] - elif "hidden_states.s" in code_context: - shape = shape + [self.tracer.root.config.hidden_size] - else: - # Default case: - # - If self.size is called for an unpacking, retrieves the corresponding unpacking - # instruction, and returns the shape padded as much as necessary to match the expected - # number of items. - # - If self.size is called outside of an unpacking context, simply return the shape. - is_unpack = False - - for inst in instructions: - if inst.opname == "UNPACK_SEQUENCE": - is_unpack = True - break - - if is_unpack and inst.argval >= 3: - shape += [self.tracer.root.config.hidden_size] - dummy_values = [1] * (inst.argval - 3) - shape += dummy_values - - if dim is not None: - return shape[dim] - - return tuple(shape) - @property def shape(self): return self.size() - def __bool__(self) -> bool: - frame = inspect.currentframe() - calling_frame = frame.f_back - code_context = inspect.getframeinfo(calling_frame).code_context[0].strip() - if calling_frame.f_code.co_name == "apply_chunking_to_forward": - # Returning True to every assertion in "apply_chuncking_to_forward" - return True - elif "assert" in code_context: - # Returning True to any assertion. - return True - elif calling_frame.f_code.co_name == "get_extended_attention_mask": - # Corresponding to: - # if causal_mask.shape[1] < attention_mask.shape[1]: - return calling_frame.f_back.f_locals["past_key_values"][0] is not None - raise NotImplementedError("__bool__ was called for CustomProxy, but this case is not covered yet.") - def __setitem__(self, key, value): pass @@ -120,28 +52,203 @@ class HFProxy(Proxy): return False +def _wrap_method_for_model_recording(model, method_name, cache_name): + """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" + method = getattr(torch.Tensor, method_name) + + @functools.wraps(method) + def wrapped(*args, **kwargs): + if not hasattr(model, cache_name): + setattr(model, cache_name, []) + cache = getattr(model, cache_name) + res = method(*args, **kwargs) + cache.append(res) + return res + + return wrapped + + +def _create_recorded_proxy_method(proxy, method_name, cache_name): + """ + Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values + during symbolic tracing. + """ + + def method(self, *args, **kwargs): + cache = getattr(self.tracer.root, cache_name) + res = cache.pop(0) + return res + + method.__name__ = method_name + bound_method = method.__get__(proxy, proxy.__class__) + setattr(proxy, method_name, bound_method) + + +def _wrap_method_for_model_tracing(model, method_name, cache_name): + """ + Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values + during symbolic tracing. + """ + + original_method = getattr(torch.Tensor, method_name) + + @functools.wraps(original_method) + def method(*args, **kwargs): + cache = getattr(model, cache_name) + res = cache.pop(0) + return res + + setattr(torch.Tensor, method_name, method) + + if method_name == "size": + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + +def _monkey_patch_tensor_methods_for_model_recording(model, method_names): + """ + Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference + before symbolic tracing. + """ + cache_names = dict() + original_methods = dict() + for method_name in method_names: + cache_name = f"cache_{method_name}" + cache_names[method_name] = cache_name + if not hasattr(torch.Tensor, method_name): + logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") + continue + original_methods[method_name] = getattr(torch.Tensor, method_name) + setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name)) + + if method_name == "size": + original_methods["shape"] = torch.Tensor.shape + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + return cache_names, original_methods + + +def _reset_tensor_methods(original_methods): + """Helper function that resets the monkey patched torch.Tensor methods to their original values.""" + for name, method in original_methods.items(): + setattr(torch.Tensor, name, method) + + class HFTracer(Tracer): """ - Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it - uses the HFProxy instead of the regular PyTorch torch.fx.Proxy. + Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the + regular PyTorch torch.fx.Proxy. """ + default_methods_to_record = {"__bool__", "size", "dim"} + def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): super().__init__() encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length - decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1 + decoder_sequence_length = ( + sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length + ) self.encoder_shape = [batch_size, encoder_sequence_length] self.decoder_shape = ( [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape) ) self.num_choices = num_choices if self.num_choices > 0: - self.encoder_shape[0] *= self.num_choices + self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length] + self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length] self.prev_module = None + self.recorded_methods = None def proxy(self, node: Node): - return HFProxy(node, self) + p = HFProxy(node, self) + if self.recorded_methods: + for method_name, cache_name in self.recorded_methods.items(): + _create_recorded_proxy_method(p, method_name, cache_name) + return p + + def _generate_dummy_input(self, model, input_name): + """Generates dummy input for model inference recording.""" + model_class = model.__class__ + device = model.device + inputs_dict = dict() + + if input_name in ["labels", "start_positions", "end_positions"]: + batch_size = self.encoder_shape[0] + if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device) + elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) + inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) + elif model_class in [ + *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING), + *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING), + *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), + ]: + inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) + elif model_class in [ + *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), + *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), + *get_values(MODEL_FOR_MASKED_LM_MAPPING), + *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), + GPT2DoubleHeadsModel, + ]: + inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device) + elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): + inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device) + else: + raise NotImplementedError(f"{model_class} not supported yet.") + + elif "mask" in input_name or "ids" in input_name: + shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape + inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device) + else: + shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape + shape += [model.config.hidden_size] + inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device) + + return inputs_dict + + def record(self, model, input_names, method_names=None): + """ + Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic + tracing. + """ + if method_names is None: + method_names = self.default_methods_to_record + + inputs = dict() + for input_name in input_names: + inputs.update(self._generate_dummy_input(model, input_name)) + + clone = copy.deepcopy(model) + cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names) + self.original_methods = original_methods + + clone(**inputs) + + _reset_tensor_methods(original_methods) + + self.recorded_methods = { + method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name) + } + + for cache_name in self.recorded_methods.values(): + setattr(model, cache_name, getattr(clone, cache_name)) + + def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: + sig = inspect.signature(root.forward) + input_names = sig.parameters.keys() - concrete_args.keys() + + self.record(root, input_names, method_names=method_names) + + for method_name, cache_name in self.recorded_methods.items(): + _wrap_method_for_model_tracing(root, method_name, cache_name) + + graph = super().trace(root, concrete_args=concrete_args) + + _reset_tensor_methods(self.original_methods) + + return graph def _insert_module_as_submodule(self, mod): """ @@ -202,6 +309,11 @@ class HFTracer(Tracer): self.prev_module = path return path + def create_arg(self, a: Any) -> Argument: + if isinstance(a, range): + return super().create_arg(list(a)) + return super().create_arg(a) + def symbolic_trace( model: PreTrainedModel, @@ -249,6 +361,7 @@ def symbolic_trace( concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) + traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index 81c5c48ccf..06e60d6925 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -229,6 +229,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 493cf7d555..2199ea282f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -600,9 +600,9 @@ class ModelTesterMixin: input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] if labels is not None: input_names.append("labels") - prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - model_output = model(**prepared_inputs) + model_output = model(**filtered_inputs) batch_size = input_ids.shape[0] encoder_sequence_length = input_ids.shape[1] @@ -615,26 +615,37 @@ class ModelTesterMixin: sequence_length=[encoder_sequence_length, decoder_sequence_length], ) - traced_output = traced_model(**prepared_inputs) + traced_output = traced_model(**filtered_inputs) else: - input_ids = inputs["input_ids"] - labels = inputs.get("labels", None) input_names = ["input_ids", "attention_mask", "token_type_ids"] + input_ids = inputs["input_ids"] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) if labels is not None: input_names.append("labels") - prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") - model_output = model(**prepared_inputs) + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = filtered_inputs.keys() - batch_size = input_ids.shape[0] + model_output = model(**filtered_inputs) - if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - sequence_length = input_ids.shape[2] - num_choices = input_ids.shape[1] - else: - sequence_length = input_ids.shape[1] + rank = len(input_ids.shape) + if rank == 2: + batch_size, sequence_length = input_ids.shape num_choices = -1 + elif rank == 3: + batch_size, num_choices, sequence_length = input_ids.shape + else: + raise NotImplementedError( + f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}." + ) traced_model = symbolic_trace( model, @@ -643,14 +654,31 @@ class ModelTesterMixin: sequence_length=sequence_length, num_choices=num_choices, ) - traced_output = traced_model(**prepared_inputs) + traced_output = traced_model(**filtered_inputs) except RuntimeError: self.fail("Couldn't trace module.") + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) num_outputs = len(model_output) - outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs)) - self.assertTrue(outputs_are_close) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) def test_headmasking(self): if not self.test_head_masking: diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 0c5c4bcf68..269cadf957 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -208,6 +208,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) + fx_ready_model_classes = all_model_classes test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 10c456d877..25c5320815 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -399,6 +399,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () + fx_ready_model_classes = all_model_classes test_missing_keys = False test_model_parallel = True diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index ccf63c5e24..b4c8d185b1 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -276,6 +276,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else () all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes test_missing_keys = False test_pruning = False test_model_parallel = False diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index 5be4716d33..7a58e9f753 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # Copyright 2021 NVIDIA Corporation. All rights reserved. # @@ -282,6 +281,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_ready_model_classes = all_model_classes # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index ce5854d16a..3ebc770252 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model