From d4e4efce68a5d18ca3175475b59051dec336fa6b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 5 Oct 2021 14:19:47 +0200 Subject: [PATCH] Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579) * Symbolic trace dynamic axes support for BERT like models (albert, bert, distilbert, mobilebert, electra, megatron-bert) * Sanity checks before tracing that make sure the model to trace is supported * Adapted to PyTorch 1.9 Co-authored-by: Michael Benayoun --- src/transformers/file_utils.py | 2 +- .../models/distilbert/modeling_distilbert.py | 2 +- src/transformers/utils/fx.py | 238 ++++++++++++- src/transformers/utils/fx_transformations.py | 321 ++++++++++++++++++ tests/test_modeling_albert.py | 1 + tests/test_modeling_bert.py | 1 + tests/test_modeling_common.py | 19 +- tests/test_modeling_distilbert.py | 1 + tests/test_modeling_electra.py | 1 + tests/test_modeling_megatron_bert.py | 1 + tests/test_modeling_mobilebert.py | 1 + 11 files changed, 571 insertions(+), 17 deletions(-) create mode 100644 src/transformers/utils/fx_transformations.py diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index cd18a2681a..6763b6a67f 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -280,7 +280,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. 
-TORCH_FX_REQUIRED_VERSION = version.parse("1.8") +TORCH_FX_REQUIRED_VERSION = version.parse("1.9") TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index bcfab1fa88..d137e10fdb 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -204,7 +204,7 @@ class MultiHeadSelfAttention(nn.Module): q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) - scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length) weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length) weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 2ae7d6c1c0..af35d66bbf 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,7 +1,8 @@ import copy import functools import inspect -from typing import Any, Dict, List, Optional, Union +import random +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch from packaging import version @@ -9,9 +10,8 @@ from torch import nn from torch.fx import Graph, GraphModule, Node, Proxy, Tracer from torch.fx.node import Argument -from transformers.file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available - from .. 
import ( + CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, @@ -22,16 +22,106 @@ from .. import ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_MAPPING, GPT2DoubleHeadsModel, + PretrainedConfig, PreTrainedModel, logging, ) +from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available from ..models.auto import get_values +from .fx_transformations import ( + _cache_attributes, + _patch_arguments_, + _restore_attributes_, + transform_to_dynamic_input_, + transformation, +) logger = logging.get_logger(__name__) +def _generate_supported_model_classes( + model_name: Type[PretrainedConfig], + supported_tasks: Optional[Union[str, List[str]]] = None, +) -> List[Type[PreTrainedModel]]: + model_config_class = CONFIG_MAPPING[model_name] + task_mapping = { + "default": MODEL_MAPPING, + "pretraining": MODEL_FOR_PRETRAINING_MAPPING, + "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + "masked-lm": MODEL_FOR_MASKED_LM_MAPPING, + "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING, + "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING, + "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + } + + if supported_tasks is None: + supported_tasks = task_mapping.keys() + if isinstance(supported_tasks, str): + supported_tasks = [supported_tasks] + + model_classes = [] + for task in supported_tasks: + model_class = task_mapping[task].get(model_config_class, None) + if model_class: + model_classes.append(model_class) + + return model_classes + + +_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ + "albert", + "bert", + "distilbert", + 
"mobilebert", + "electra", + "megatron-bert", + "gpt2", + "gptj", + "gpt_neo", + "t5", +] + +_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [ + "albert", + "bert", + "distilbert", + "mobilebert", + "electra", + "megatron-bert", +] + +_REGULAR_SUPPORTED_MODELS = [] +for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: + if isinstance(item, dict): + _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item)) + else: + _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item)) + +_SPECIAL_SUPPORTED_MODELS = [ + GPT2DoubleHeadsModel, +] +_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS) + +_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] +for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES: + if isinstance(item, dict): + _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item)) + else: + _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item)) + +_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] +_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple( + _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES +) + + class HFProxy(Proxy): """ Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing @@ -228,7 +318,7 @@ class HFTracer(Tracer): if method_names is None: method_names = self.default_methods_to_record - inputs = dict() + inputs = {} for input_name in input_names: inputs.update(self._generate_dummy_input(model, input_name)) @@ -251,6 +341,22 @@ class HFTracer(Tracer): for cache_name in self.recorded_methods.values(): setattr(model, cache_name, getattr(clone, cache_name)) + def _module_getattr(self, attr, attr_val, parameter_proxy_cache): + if isinstance(attr_val, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if attr_val is p: + if n not in parameter_proxy_cache: + parameter_proxy_cache[n] = 
self.create_proxy("get_attr", n, (), {}) + return parameter_proxy_cache[n] + # TODO: condition this on whether dynamic axes were requested. + if isinstance(attr_val, torch.Tensor): + for n, p in self.root.named_buffers(): + if attr_val is p: + if n not in parameter_proxy_cache: + parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {}) + return parameter_proxy_cache[n] + return attr_val + def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: sig = inspect.signature(root.forward) input_names = sig.parameters.keys() - concrete_args.keys() @@ -264,6 +370,19 @@ class HFTracer(Tracer): _reset_tensor_methods(self.original_methods) + # TODO: keep this as long as necessary. + # This is necessary because concrete args are added as input to the traced module since + # https://github.com/pytorch/pytorch/pull/55888. + # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet. + for node in graph.nodes: + if node.op == "placeholder": + # Removing default values for inputs as the forward pass will fail with them. + if node.target in input_names: + node.args = () + # It is a concrete arg so it is not used and should be removed. 
+ else: + graph.erase_node(node) + return graph def _insert_module_as_submodule(self, mod): @@ -295,7 +414,7 @@ class HFTracer(Tracer): if path is None: path = self._insert_module_as_submodule(mod) if path is None: - raise NameError("module is not installed as a submodule") + raise NameError(f"Module named {mod._get_name()} is not installed as a submodule") self.prev_module = path return path @@ -308,7 +427,7 @@ class HFTracer(Tracer): return n path = self._insert_module_as_submodule(mod) if path is None: - raise NameError("module is not installed as a submodule") + raise NameError(f"Module {mod._get_name()} is not installed as a submodule") self.prev_module = path return path @@ -318,11 +437,65 @@ class HFTracer(Tracer): return super().create_arg(a) +@transformation +def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: + """ + Prepares a GraphModule produced by symbolic_trace for retracing by: + + - Caching all the attributes specific to the way the model was initially traced + - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes + For instance, the need to retrace a GraphModule can happen when applying quantization. + """ + attributes = _cache_attributes(gm) + _patch_arguments_(gm, gm.dynamic2static) + + return gm, attributes + + +def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): + """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" + _restore_attributes_(gm, attributes) + # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired + # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. 
+ transform_to_dynamic_input_(gm, is_retracing=True) + _patch_arguments_(gm, gm.static2dynamic) + return gm + + +def retrace_graph_with( + gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None +) -> GraphModule: + """ + Retraces a GraphModule by either using a tracer or a function using a tracer (for instance + torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and + restoring anything necessary after the retrace. + """ + if tracer is None and func is None: + raise ValueError("Either a tracer or a function using a tracer must be provided.") + elif tracer is not None and func is not None: + raise ValueError("Either provide a tracer or a function using a tracer, but not both.") + else: + gm, attributes = prepare_for_retracing(gm) + tracing_func = tracer.trace if tracer else func + traced = tracing_func(gm) + restore_after_retracing_(traced, attributes) + return traced + + +def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): + if forbidden_values is None: + forbidden_values = [] + value = random.randint(low, high) + while value in forbidden_values: + value = random.randint(low, high) + return value + + def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, batch_size: int = 1, - sequence_length: Union[int, List[int]] = [128, 128], + sequence_length: Union[int, List[int], Tuple[int]] = (128, 128), num_choices: int = -1, ) -> GraphModule: @@ -360,12 +533,61 @@ def symbolic_trace( input_names = model.dummy_inputs.keys() sig = inspect.signature(model.forward) - # TODO: how to handle the case of the "return_dict" parameter. concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} + # Preparing HFTracer batch_size and sequence_length values for potential dynamic axes. 
+ use_dynamic_batch_size = batch_size <= 0 + if isinstance(sequence_length, (list, tuple)): + use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0 + else: + use_dynamic_sequence_length = sequence_length <= 0 + + if use_dynamic_batch_size or use_dynamic_sequence_length: + forbidden_values = [ + model.config.num_attention_heads, + model.config.hidden_size, + model.config.hidden_size // model.config.num_attention_heads, + ] + if use_dynamic_batch_size: + batch_size = _generate_random_int(forbidden_values=forbidden_values) + forbidden_values.append(batch_size) + if use_dynamic_sequence_length: + encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) + forbidden_values.append(encoder_sequence_length) + decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) + sequence_length = [encoder_sequence_length, decoder_sequence_length] + + if not isinstance(model, _SUPPORTED_MODELS): + supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) + raise NotImplementedError( + f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" + ) + if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( + model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES + ): + supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) + raise NotImplementedError( + f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" + ) + + # Tracing. 
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) + traced.config = copy.deepcopy(model.config) + traced.num_choices = num_choices + traced.dummy_inputs = {} + + for name in input_names: + traced.dummy_inputs.update(tracer._generate_dummy_input(model, name)) + + traced.use_dynamic_batch_size = use_dynamic_batch_size + traced.use_dynamic_sequence_length = use_dynamic_sequence_length + traced.static_batch_size = batch_size + traced.static_sequence_length = sequence_length + + transform_to_dynamic_input_(traced) + return traced diff --git a/src/transformers/utils/fx_transformations.py b/src/transformers/utils/fx_transformations.py new file mode 100644 index 0000000000..3e181617af --- /dev/null +++ b/src/transformers/utils/fx_transformations.py @@ -0,0 +1,321 @@ +import copy +import functools +import operator +from inspect import signature +from typing import Any, Callable, Dict, Optional, Union + +import torch +from torch.fx import Graph, GraphModule, Node + + +# Torch FX transformation convention: +# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation +# - transformations that are inplace have a name ending with "_" + + +def _cache_attributes(gm: GraphModule) -> Dict[str, Any]: + attributes_to_keep = [ + "config", + "num_choices", + "dummy_inputs", + "use_dynamic_batch_size", + "use_dynamic_sequence_length", + "static_batch_size", + "static_sequence_length", + "static2dynamic", + "dynamic2static", + ] + attributes = {k: getattr(gm, k, None) for k in attributes_to_keep} + return attributes + + +def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]): + for name, attr in attributes.items(): + setattr(gm, name, attr) + + +def deepcopy_graph(gm: GraphModule) -> GraphModule: + """ + Performs a deepcopy of the GraphModule while also 
copying the relevant attributes to know whether the model was + traced with dynamic axes, and what were the values if that is the case. + """ + + # First, create a copy of the module without the graph. + graph = gm.__dict__.pop("_graph") + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(gm.__dict__) + gm.__dict__["_graph"] = graph + + # Then, copy the graph. + val_map = {} + graph_clone = Graph() + output_val = graph_clone.graph_copy(graph, val_map=val_map) + graph_clone.output(output_val) + + # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies. + # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule. + clone = gm.__class__(fake_mod, graph_clone) + + # Restore the dynamic axes related attributes to the clone. + attributes = _cache_attributes(gm) + attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()} + attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()} + _restore_attributes_(clone, attributes) + + return clone + + +def transformation(func): + """ + Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the + original. + """ + + def map_fn(arg): + if isinstance(arg, GraphModule): + return deepcopy_graph(arg) + return arg + + @functools.wraps(func) + def wrapper(*args, **kwargs): + new_args = tuple(map_fn(arg) for arg in args) + new_kwargs = {k: map_fn(v) for k, v in kwargs.items()} + return func(*new_args, **new_kwargs) + + wrapper._is_transformation = True + + return wrapper + + +def compose_transformations( + *args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False +) -> GraphModule: + """ + Allows to compose transformations together and takes care of: + + 1. 
Performing the transformations on a copy of the GraphModule if inplace is set to False, transformations that + are decorated with @transformation (which means that they are not modifying the original GraphModule) are + unwrapped to make them inplace. + 2. Linting and recompiling only at the end of the composition for performance purposes. + """ + args = list(args) + if not inplace: + args.insert(0, deepcopy_graph) + + for i, transformation in enumerate(args[:-1]): + sig = signature(transformation) + + # Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is + # already handled by this function. + if getattr(transformation, "_is_transformation", False): + transformation = transformation.__wrapped__ + + # Linting and recompiling only after the last transformation applied to make composition efficient. + if "lint_and_recompile" in sig.parameters: + args[i] = functools.partial(transformation, lint_and_recompile=False) + + def reduce_func(f, g): + def compose_f_and_g(gm): + output_g = g(gm) + if output_g is None: + output_g = gm + output_f = f(output_g) + if output_f is None: + output_f = gm + return output_f + + return compose_f_and_g + + return functools.reduce(reduce_func, reversed(args), lambda x: x) + + +def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True): + """Removes all the unused nodes in a GraphModule.""" + graph = gm.graph + for node in graph.nodes: + if not node.users and node.op not in ["placeholder", "output"]: + graph.erase_node(node) + + if lint_and_recompile: + graph.lint() + gm.recompile() + + +def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: + """Inserts a node that retrieves the batch size dynamically from the input of the model.""" + graph = gm.graph + input_names = set(gm.dummy_inputs.keys()) + batch_size_node = None + for node in graph.nodes: + if node.op == "placeholder" and node.name in input_names: + with 
graph.inserting_after(node): + batch_size_node = graph.call_method("size", args=(node, 0)) + + if batch_size_node is None: + raise ValueError("Could not insert the node that computes the batch size") + + if lint_and_recompile: + graph.lint() + gm.recompile() + + # Useful when retracing for quantization. + if hasattr(gm, "_qconfig_map"): + gm._qconfig_map[batch_size_node.name] = None + + return batch_size_node + + +def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: + """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model.""" + graph = gm.graph + input_names = set(gm.dummy_inputs.keys()) + encoder_sequence_length_node = None + for node in graph.nodes: + if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name: + with graph.inserting_after(node): + # There are two cases to handle: + # 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the + # input shapes is [batch_size, sequence_length] => index 1 + # 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input + # shape is [batch_size, num_choices, sequence_length] => index 2 + encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2)) + + if encoder_sequence_length_node is None: + raise ValueError("Could not insert the node that computes the encoder sequence length") + + if lint_and_recompile: + graph.lint() + gm.recompile() + + # Useful when retracing for quantization. 
+ if hasattr(gm, "_qconfig_map"): + gm._qconfig_map[encoder_sequence_length_node.name] = None + + return encoder_sequence_length_node + + +def _change_view_methods_( + gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True +): + """ + Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the + batch_size / sequence_length nodes. + """ + graph = gm.graph + for node in graph.nodes: + if node.op == "call_method" and node.target == "view": + if isinstance(node.args[1], tuple): + node.args = (node.args[0], *node.args[1]) + node.args = tuple((mapping.get(arg, arg) for arg in node.args)) + + if lint_and_recompile: + graph.lint() + gm.recompile() + + +def _patch_getitem_( + gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True +): + """Patches getitem nodes by replacing current arguments with their corresponding values in mapping.""" + # TODO: combine this with the _patch_arguments_ function which seems to do almost the same thing. 
+ graph = gm.graph + for node in graph.nodes: + if node.op == "call_function" and node.target == operator.getitem: + indices = node.args[1] + if isinstance(indices, tuple): + new_indices = [] + for idx in indices: + if isinstance(idx, slice): + new_indices.append( + slice( + mapping.get(idx.start, idx.start), + mapping.get(idx.stop, idx.stop), + mapping.get(idx.step, idx.step), + ) + ) + elif isinstance(idx, int): + new_indices.append(mapping.get(idx, idx)) + else: + new_indices.append(idx) + + node.args = (node.args[0], tuple(new_indices)) + else: + node.args = (node.args[0], mapping.get(node.args[1], node.args[1])) + + if lint_and_recompile: + graph.lint() + gm.recompile() + + +def _patch_arguments_( + gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True +): + """ + Patches nodes by replacing their arguments with their corresponding values in mapping (supports regular types, + tuples and slices). + """ + + def _patch_slice(s, mapping): + return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step)) + + graph = gm.graph + supported_types = (Node, str, int, float) + for node in graph.nodes: + new_args = [] + for arg in node.args: + if isinstance(arg, tuple): + new_arg = [] + for a in arg: + if isinstance(a, slice): + new_arg.append(_patch_slice(a, mapping)) + else: + new_arg.append(mapping.get(a, a)) + new_args.append(tuple(new_arg)) + elif isinstance(arg, slice): + new_args.append(_patch_slice(arg, mapping)) + elif isinstance(arg, supported_types): + new_args.append(mapping.get(arg, arg)) + else: + new_args.append(arg) + node.args = tuple(new_args) + + if lint_and_recompile: + graph.lint() + gm.recompile() + + +def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False): + """Transformation that enables traced models to perform inference on dynamic input shapes.""" + graph = gm.graph + static2dynamic = {} + + # Inserting the nodes that will fetch the batch size and 
sequence lengths dynamically. + if gm.use_dynamic_batch_size: + batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False) + static2dynamic[gm.static_batch_size] = batch_size_node + if gm.num_choices > 0: + with graph.inserting_after(batch_size_node): + static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function( + operator.mul, args=(batch_size_node, gm.num_choices) + ) + # Useful when retracing for quantization. + if hasattr(gm, "_qconfig_map"): + gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None + + if gm.use_dynamic_sequence_length: + encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False) + static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node + + # TODO: do the same for the decoder. + pass + + _change_view_methods_(gm, static2dynamic, lint_and_recompile=False) + _patch_getitem_(gm, static2dynamic, lint_and_recompile=False) + + remove_unused_nodes_(gm, lint_and_recompile=False) + + graph.lint() + gm.recompile() + + gm.static2dynamic = static2dynamic + gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()} diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index 5be455b630..a8228d3e13 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -232,6 +232,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes + fx_dynamic_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index a029d9d47e..7b7f02a553 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -445,6 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () fx_ready_model_classes = all_model_classes + 
fx_dynamic_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b61cf834fb..7c1ca8dd44 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -92,6 +92,7 @@ class ModelTesterMixin: all_model_classes = () all_generative_model_classes = () fx_ready_model_classes = () + fx_dynamic_ready_model_classes = () test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -607,14 +608,19 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + def test_torch_fx_dynamic_axes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True) + + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False): if not is_torch_fx_available(): return configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.return_dict = False - for model_class in self.fx_ready_model_classes: + model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes + for model_class in model_classes: model = model_class(config=configs_no_init) model.to(torch_device) model.eval() @@ -640,12 +646,11 @@ class ModelTesterMixin: traced_model = symbolic_trace( model, input_names, - batch_size=batch_size, - sequence_length=[encoder_sequence_length, decoder_sequence_length], + batch_size=batch_size if not dynamic_axes else -1, + sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1, ) traced_output = traced_model(**filtered_inputs) - else: 
input_names = ["input_ids", "attention_mask", "token_type_ids"] input_ids = inputs["input_ids"] @@ -679,8 +684,8 @@ class ModelTesterMixin: traced_model = symbolic_trace( model, input_names, - batch_size=batch_size, - sequence_length=sequence_length, + batch_size=batch_size if not dynamic_axes else -1, + sequence_length=sequence_length if not dynamic_axes else -1, num_choices=num_choices, ) traced_output = traced_model(**filtered_inputs) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 87ebaa22ee..ed7fba94bb 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -210,6 +210,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): else None ) fx_ready_model_classes = all_model_classes + fx_dynamic_ready_model_classes = all_model_classes test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index a19af17f52..2b19bb4a5d 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes + fx_dynamic_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index 5a06d57a9e..a7f47ddea3 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -284,6 +284,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes + fx_dynamic_ready_model_classes = all_model_classes # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index ec90b9b1b7..23eb5a9c5e 100644 --- a/tests/test_modeling_mobilebert.py +++ 
b/tests/test_modeling_mobilebert.py @@ -270,6 +270,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): else () ) fx_ready_model_classes = all_model_classes + fx_dynamic_ready_model_classes = all_model_classes test_sequence_classification_problem_types = True # special case for ForPreTraining model