From 0fe17f375a4f0fdd9aea260d0645ccfd4896e958 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 7 Feb 2022 22:25:33 +0100 Subject: [PATCH] FX tracing improvement (#14321) * Change the way tracing happens, enabling dynamic axes out of the box * Update the tests and modeling xlnet * Add the non recoding of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors). * Comments and making tracing work for gpt-j and xlnet * Refactore things related to num_choices (and batch_size, sequence_length) * Update fx to work on PyTorch 1.10 * Postpone autowrap_function feature usage for later * Add copyrights * Remove unnecessary file * Fix issue with add_new_model_like * Apply suggestions --- .../commands/add_new_model_like.py | 17 + src/transformers/file_utils.py | 2 +- src/transformers/modeling_utils.py | 46 +- .../models/albert/modeling_albert.py | 2 +- src/transformers/models/bert/modeling_bert.py | 4 +- .../models/electra/modeling_electra.py | 4 +- src/transformers/models/gpt2/modeling_gpt2.py | 8 +- .../models/gpt_neo/modeling_gpt_neo.py | 6 +- src/transformers/models/gptj/modeling_gptj.py | 6 +- .../models/layoutlm/modeling_layoutlm.py | 4 +- .../megatron_bert/modeling_megatron_bert.py | 4 +- .../models/mobilebert/modeling_mobilebert.py | 4 +- .../models/realm/modeling_realm.py | 4 +- .../models/roberta/modeling_roberta.py | 4 +- .../models/splinter/modeling_splinter.py | 4 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 4 +- src/transformers/utils/fx.py | 520 +++++++++--------- src/transformers/utils/fx_transformations.py | 321 ----------- tests/test_modeling_albert.py | 3 +- tests/test_modeling_bert.py | 3 +- tests/test_modeling_common.py | 43 +- tests/test_modeling_distilbert.py | 3 +- tests/test_modeling_electra.py | 5 +- tests/test_modeling_gpt2.py | 2 +- tests/test_modeling_gpt_neo.py | 2 +- tests/test_modeling_gptj.py | 2 +- tests/test_modeling_megatron_bert.py | 4 +- tests/test_modeling_mobilebert.py | 3 +- tests/test_modeling_roberta.py | 1 + tests/test_modeling_t5.py | 2 +- tests/test_modeling_xlnet.py | 1 + 31 files changed, 349 insertions(+), 689 deletions(-) delete mode 100644 src/transformers/utils/fx_transformations.py diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py index e443a235c4..3ba5d71099 100644 --- a/src/transformers/commands/add_new_model_like.py +++ b/src/transformers/commands/add_new_model_like.py @@ -1189,6 +1189,16 @@ def create_new_model_like( if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f) ] + def disable_fx_test(filename: Path) -> bool: + with open(filename) as fp: + content = fp.read() + new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content) + with open(filename, "w") as fp: + fp.write(new_content) + return content != new_content + + disabled_fx_test = False + for test_file in files_to_adapt: new_test_file_name = test_file.name.replace( old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased @@ -1201,6 +1211,13 @@ def create_new_model_like( dest_file=dest_file, add_copied_from=False, ) + disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file) + + if disabled_fx_test: + print( + "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works " + "for your new model." + ) # 4. Add model to auto classes add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 1a4387d848..abc26cada9 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -322,7 +322,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. -TORCH_FX_REQUIRED_VERSION = version.parse("1.9") +TORCH_FX_REQUIRED_VERSION = version.parse("1.10") TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index abcc26052d..4845ff0743 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -247,6 +247,27 @@ class ModuleUtilsMixin: return encoder_extended_attention_mask + def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device): + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. @@ -271,26 +292,9 @@ class ModuleUtilsMixin: # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder: - batch_size, seq_length = input_shape - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - # in case past_key_values are used we need to add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - causal_mask = torch.cat( - [ - torch.ones( - (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype - ), - causal_mask, - ], - axis=-1, - ) - - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + extended_attention_mask = self.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) else: extended_attention_mask = attention_mask[:, None, None, :] else: @@ -1861,7 +1865,7 @@ class Conv1D(nn.Module): def forward(self, x): size_out = x.size()[:-1] + (self.nf,) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(*size_out) + x = x.view(size_out) return x diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 6f443fb4f8..a54a3874ad 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -293,7 +293,7 @@ class AlbertAttention(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def prune_heads(self, heads): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 23dfbcee63..26c629f78f 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -252,7 +252,7 @@ class BertSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -341,7 +341,7 @@ class BertSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 054eff4be0..a61045f9c0 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -245,7 +245,7 @@ class ElectraSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -334,7 +334,7 @@ class ElectraSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b1988d7edf..59df99e8ab 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -193,7 +193,7 @@ class GPT2Attention(nn.Module): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -281,7 +281,7 @@ class GPT2Attention(nn.Module): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -915,7 +915,7 @@ class GPT2Model(GPT2PreTrainedModel): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel): f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7176cfa790..c516ca57a1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -173,7 +173,7 @@ class GPTNeoSelfAttention(nn.Module): Splits hidden_size dim into attn_head_size and num_heads """ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): @@ -637,7 +637,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -891,7 +891,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel): f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[torch.arange(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 869014bee6..66163ad49f 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -107,7 +107,7 @@ class GPTJAttention(nn.Module): Splits hidden dim into attn_head_size and num_attention_heads """ new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(*new_shape) + tensor = tensor.view(new_shape) if rotary: return tensor if len(tensor.shape) == 5: @@ -665,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel): hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + hidden_states = hidden_states.view(output_shape) # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,7 +945,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel): f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) - pooled_logits = logits[range(batch_size), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths] loss = None if labels is not None: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index d595fc8b51..bbdfeaac83 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -160,7 +160,7 @@ class LayoutLMSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -249,7 +249,7 @@ class LayoutLMSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index dbfb76cb5d..292b920bf5 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -223,7 +223,7 @@ class MegatronBertSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -312,7 +312,7 @@ class MegatronBertSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 2a90f1d92a..acf9607a73 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -237,7 +237,7 @@ class MobileBertSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -274,7 +274,7 @@ class MobileBertSelfAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 165e62c0ef..1189164138 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -260,7 +260,7 @@ class RealmSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -349,7 +349,7 @@ class RealmSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 426095e03e..88f0aa8d29 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -187,7 +187,7 @@ class RobertaSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -276,7 +276,7 @@ class RobertaSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index b982a38b62..3d15cc6825 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -127,7 +127,7 @@ class SplinterSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -216,7 +216,7 @@ class SplinterSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index cdea06ac57..cfeb788ec6 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -181,7 +181,7 @@ class XLMRobertaXLSelfAttention(nn.Module): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -270,7 +270,7 @@ class XLMRobertaXLSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 23a2eb4c1f..f9cdc407ae 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -1,8 +1,24 @@ -import copy +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import functools import inspect +import math import random -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from types import ModuleType +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union import torch from packaging import version @@ -26,17 +42,11 @@ from .. import ( GPT2DoubleHeadsModel, PretrainedConfig, PreTrainedModel, + XLNetForQuestionAnswering, logging, ) from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available from ..models.auto import get_values -from .fx_transformations import ( - _cache_attributes, - _patch_arguments_, - _restore_attributes_, - transform_to_dynamic_input_, - transformation, -) logger = logging.get_logger(__name__) @@ -46,6 +56,7 @@ def _generate_supported_model_classes( model_name: Type[PretrainedConfig], supported_tasks: Optional[Union[str, List[str]]] = None, ) -> List[Type[PreTrainedModel]]: + model_config_class = CONFIG_MAPPING[model_name] task_mapping = { "default": MODEL_MAPPING, @@ -86,15 +97,10 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "gptj", "gpt_neo", "t5", -] - -_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [ - "albert", - "bert", - "distilbert", - "mobilebert", - "electra", - "megatron-bert", + "roberta", + # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # "layoutlm", + # "xlnet", ] _REGULAR_SUPPORTED_MODELS = [] @@ -106,21 +112,11 @@ for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS: _SPECIAL_SUPPORTED_MODELS = [ GPT2DoubleHeadsModel, + # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # XLNetForQuestionAnswering, ] _SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS) -_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES: - if isinstance(item, dict): - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item)) - else: - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item)) - -_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = [] -_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple( - _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES -) - class HFProxy(Proxy): """ @@ -134,6 +130,7 @@ class HFProxy(Proxy): if hasattr(self, "tracer") and self.tracer is not None: self.device = self.tracer.root.device self.dtype = next(self.tracer.root.parameters()).dtype + self.cache = None @property def shape(self): @@ -145,98 +142,124 @@ class HFProxy(Proxy): def __contains__(self, key): return False + def __eq__(self, other): + if self.cache is not None: + return self.cache == other + elif isinstance(other, HFProxy): + return True + else: + return super().__eq__(other) -def _wrap_method_for_model_recording(model, method_name, cache_name): - """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" - method = getattr(torch.Tensor, method_name) + def __ne__(self, other): + return not self == other - @functools.wraps(method) - def wrapped(*args, **kwargs): - if not hasattr(model, cache_name): - setattr(model, cache_name, []) - cache = getattr(model, cache_name) - res = method(*args, **kwargs) - cache.append(res) - return res + def __len__(self): + if self.cache is not None: + if isinstance(self.cache, int): + return self.cache + elif isinstance(self.cache, (torch.Size, list, tuple)): + return len(self.cache) + else: + return super().__len__(self) + return super().__len__(self) - return wrapped + def __torch_function__(self, orig_method, types, args=None, kwargs=None): + proxy = super().__torch_function__(orig_method, types, args=args, kwargs=kwargs) + proxy.cache = self.cache + return proxy -def _create_recorded_proxy_method(proxy, method_name, cache_name): +def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]: + """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +def _function_leaf_getter(func_name: str, mapping: Dict[str, Callable[..., Any]]) -> Callable[..., Any]: + @functools.wraps(mapping[func_name]) + def wrapper(*args, **kwargs): + return mapping[func_name](*args, **kwargs) + + return wrapper + + +def _create_recorded_proxy_method(proxy: HFProxy, method_name: str, cache_name: str, return_proxy: bool): """ Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values during symbolic tracing. """ - def method(self, *args, **kwargs): - cache = getattr(self.tracer.root, cache_name) - res = cache.pop(0) - return res - - method.__name__ = method_name - bound_method = method.__get__(proxy, proxy.__class__) - setattr(proxy, method_name, bound_method) - - -def _wrap_method_for_model_tracing(model, method_name, cache_name): - """ - Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values - during symbolic tracing. - """ - original_method = getattr(torch.Tensor, method_name) @functools.wraps(original_method) def method(*args, **kwargs): - cache = getattr(model, cache_name) + cache = getattr(args[0].tracer.root, cache_name) res = cache.pop(0) + if return_proxy: + proxy = args[0].__torch_function__( + original_method, + None, + args=args, + kwargs=kwargs, + ) + proxy.cache = res + return proxy return res - setattr(torch.Tensor, method_name, method) - - if method_name == "size": - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + method.__name__ = method_name + bound_method = method.__get__(proxy, proxy.__class__) + setattr(proxy, method_name, bound_method) -def _monkey_patch_tensor_methods_for_model_recording(model, method_names): - """ - Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference - before symbolic tracing. - """ - cache_names = dict() - original_methods = dict() - for method_name in method_names: - cache_name = f"cache_{method_name}" - cache_names[method_name] = cache_name - if not hasattr(torch.Tensor, method_name): - logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") - continue - original_methods[method_name] = getattr(torch.Tensor, method_name) - setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name)) - - if method_name == "size": - original_methods["shape"] = torch.Tensor.shape - setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) - - return cache_names, original_methods - - -def _reset_tensor_methods(original_methods): +def _reset_tensor_methods(original_methods: Dict[str, Callable[..., Any]]): """Helper function that resets the monkey patched torch.Tensor methods to their original values.""" for name, method in original_methods.items(): setattr(torch.Tensor, name, method) +def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): + if forbidden_values is None: + forbidden_values = [] + value = random.randint(low, high) + while value in forbidden_values: + value = random.randint(low, high) + return value + + class HFTracer(Tracer): """ Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the regular PyTorch torch.fx.Proxy. """ - default_methods_to_record = {"__bool__", "size", "dim"} + _DEFAULT_METHODS_TO_RECORD = {"__bool__": False, "size": True, "dim": False} + from transformers import modeling_utils - def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): - super().__init__() + _FUNCTIONS_TO_AUTOWRAP = { + torch: {"arange", "zeros", "ones", "full_like", "eye"}, + modeling_utils.ModuleUtilsMixin: {"create_extended_attention_mask_for_decoder"}, + } + + def __init__(self, autowrap_modules=(math,), autowrap_functions=(), enable_cpatching=False): + + # Loading the leaf functions register + self._leaf_functions_register = {} + for module, names in self._FUNCTIONS_TO_AUTOWRAP.items(): + for name in names: + self._register_leaf_function(module, name) + + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + # autowrap_functions = autowrap_functions + tuple( + # patched for (_, _, patched) in self._leaf_functions_register.values() + # ) + + super().__init__( + autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions, enable_cpatching=enable_cpatching + ) if not is_torch_fx_available(): torch_version = version.parse(importlib_metadata.version("torch")) @@ -245,40 +268,107 @@ class HFTracer(Tracer): f"{TORCH_FX_REQUIRED_VERSION} is supported." ) - encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length - decoder_sequence_length = ( - sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length - ) - self.encoder_shape = [batch_size, encoder_sequence_length] - self.decoder_shape = ( - [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape) - ) - self.num_choices = num_choices - if self.num_choices > 0: - self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length] - self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length] - self.prev_module = None self.recorded_methods = None - def proxy(self, node: Node): - p = HFProxy(node, self) - if self.recorded_methods: - for method_name, cache_name in self.recorded_methods.items(): - _create_recorded_proxy_method(p, method_name, cache_name) - return p + def _register_leaf_function(self, module: ModuleType, name: str): + """Registers the function called name in module as a leaf function.""" + orig_func = getattr(module, name) + patched_func = _function_to_leaf(orig_func) + patched_func.__module__ = __name__ + self._leaf_functions_register[name] = (module, orig_func, patched_func) - def _generate_dummy_input(self, model, input_name): + def _patch_leaf_functions_for_root(self, root: PreTrainedModel, restore: bool = False): + """Patches leaf functions specifically for root.""" + for name in self._leaf_functions_register: + module, orig_func, patched_func = self._leaf_functions_register[name] + if restore: + root.__class__.forward.__globals__.pop(name) + setattr(module, name, orig_func) + else: + root.__class__.forward.__globals__[name] = patched_func + leaf_getter = _function_leaf_getter(name, root.__class__.forward.__globals__) + leaf_getter.__module__ = __name__ + setattr(module, name, leaf_getter) + + def _method_is_called_in_leaf_module(self, module_ids: List[int]) -> bool: + """ + Finds out if the method (that is being recorded) is called inside a leaf module, this allows to not record + outputs that will not be encountered by the tracer. + """ + + currentframe = inspect.currentframe() + while currentframe: + if currentframe is None: + return False + module = currentframe.f_locals.get("self", None) + if id(module) in module_ids and self.is_leaf_module(module, "Not used anyway"): + return True + currentframe = currentframe.f_back + return False + + def _wrap_method_for_model_recording( + self, model: PreTrainedModel, method_name: str, cache_name: str, module_ids: List[int] + ): + """Helper function that wraps a torch.Tensor method to record its outputs during forward pass.""" + method = getattr(torch.Tensor, method_name) + + @functools.wraps(method) + def wrapped(*args, **kwargs): + if self._method_is_called_in_leaf_module(module_ids): + return method(*args, **kwargs) + if not hasattr(model, cache_name): + setattr(model, cache_name, []) + cache = getattr(model, cache_name) + res = method(*args, **kwargs) + cache.append(res) + return res + + return wrapped + + def _monkey_patch_tensor_methods_for_model_recording(self, model: PreTrainedModel, method_names: Iterable[str]): + """ + Helper function that patches torch.Tensor methods (specified by the method_names list) to record model + inference before symbolic tracing. + """ + cache_names = {} + original_methods = {} + module_ids = set(id(mod) for mod in model.modules()) + for method_name in method_names: + cache_name = f"cache_{method_name}" + cache_names[method_name] = cache_name + if not hasattr(torch.Tensor, method_name): + logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.") + continue + original_methods[method_name] = getattr(torch.Tensor, method_name) + setattr( + torch.Tensor, + method_name, + self._wrap_method_for_model_recording(model, method_name, cache_name, module_ids), + ) + + if method_name == "size": + original_methods["shape"] = torch.Tensor.shape + setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name))) + + return cache_names, original_methods + + def _generate_dummy_input( + self, model: PreTrainedModel, input_name: str, shape: List[int] + ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" model_class = model.__class__ device = model.device - inputs_dict = dict() + inputs_dict = {} if input_name in ["labels", "start_positions", "end_positions"]: - batch_size = self.encoder_shape[0] + batch_size = shape[0] if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) + elif model_class in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), + XLNetForQuestionAnswering, + ]: inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ @@ -288,59 +378,56 @@ class HFTracer(Tracer): ]: inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class in [ + *get_values(MODEL_FOR_PRETRAINING_MAPPING), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), *get_values(MODEL_FOR_MASKED_LM_MAPPING), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), GPT2DoubleHeadsModel, ]: - inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device) - elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): - inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device) + inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: raise NotImplementedError(f"{model_class} not supported yet.") elif "mask" in input_name or "ids" in input_name: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device) + inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) else: - shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape - shape += [model.config.hidden_size] - inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device) + shape_with_hidden_size = shape + [model.config.hidden_size] + inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device) return inputs_dict - def record(self, model, input_names, method_names=None): + def record(self, model: PreTrainedModel, input_names: List[str], method_names: Optional[Iterable[str]] = None): """ - Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic - tracing. + Records torch.Tensor method outputs (specified by method_names) that will then be used during symbolic tracing. """ if method_names is None: - method_names = self.default_methods_to_record + method_names = self._DEFAULT_METHODS_TO_RECORD + + # Creating a random input shape to generate dummy inputs. + batch_size = _generate_random_int() + sequence_length = _generate_random_int() + shape = [batch_size, sequence_length] + + if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + num_choices = _generate_random_int(low=2, high=5) + shape.insert(1, num_choices) inputs = {} for input_name in input_names: - inputs.update(self._generate_dummy_input(model, input_name)) + inputs.update(self._generate_dummy_input(model, input_name, shape)) - clone = copy.deepcopy(model) - cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names) + cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names) self.original_methods = original_methods - clone(**inputs) - - # Useful because sometime the config is changed at inference time, for instance for - # classification tasks where config.problem_type can be set. - model.config = clone.config + model(**inputs) _reset_tensor_methods(original_methods) self.recorded_methods = { - method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name) + method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name) } - for cache_name in self.recorded_methods.values(): - setattr(model, cache_name, getattr(clone, cache_name)) - def _module_getattr(self, attr, attr_val, parameter_proxy_cache): if isinstance(attr_val, torch.nn.Parameter): for n, p in self.root.named_parameters(): @@ -357,7 +444,20 @@ class HFTracer(Tracer): return parameter_proxy_cache[n] return attr_val - def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph: + def proxy(self, node: Node): + p = HFProxy(node, self) + if self.recorded_methods: + for method_name, cache_name in self.recorded_methods.items(): + return_proxy = self._DEFAULT_METHODS_TO_RECORD[method_name] + _create_recorded_proxy_method(p, method_name, cache_name, return_proxy) + return p + + def trace( + self, + root: PreTrainedModel, + concrete_args: Optional[Dict[str, Any]] = None, + method_names: Optional[Iterable[str]] = None, + ) -> Graph: if concrete_args is None: concrete_args = {} @@ -366,11 +466,16 @@ class HFTracer(Tracer): self.record(root, input_names, method_names=method_names) - for method_name, cache_name in self.recorded_methods.items(): - _wrap_method_for_model_tracing(root, method_name, cache_name) + # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer. + autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()] + self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) + + self._patch_leaf_functions_for_root(root) graph = super().trace(root, concrete_args=concrete_args) + self._patch_leaf_functions_for_root(root, restore=True) + _reset_tensor_methods(self.original_methods) # TODO: keep this until necessary. @@ -388,7 +493,7 @@ class HFTracer(Tracer): return graph - def _insert_module_as_submodule(self, mod): + def _insert_module_as_submodule(self, mod: nn.Module) -> str: """ Helper method which tries to insert a module that was not declared as submodule. """ @@ -434,72 +539,19 @@ class HFTracer(Tracer): self.prev_module = path return path + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: + is_loss_module = m.__module__.startswith("torch.nn.modules.loss") + return (not is_loss_module) and super().is_leaf_module(m, module_qualified_name) + def create_arg(self, a: Any) -> Argument: if isinstance(a, range): return super().create_arg(list(a)) return super().create_arg(a) -@transformation -def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]: - """ - Prepares a GraphModule produced by symbolic_trace for retracing by: - - - Caching all the attributes specific to the way the model was initially traced - - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes - For instance, the need to retrace a GraphModule can happen when applying quantization. - """ - attributes = _cache_attributes(gm) - _patch_arguments_(gm, gm.dynamic2static) - - return gm, attributes - - -def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]): - """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes.""" - _restore_attributes_(gm, attributes) - # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries which is the desired - # behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values. - transform_to_dynamic_input_(gm, is_retracing=True) - _patch_arguments_(gm, gm.static2dynamic) - return gm - - -def retrace_graph_with( - gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None -) -> GraphModule: - """ - Retraces a GraphModule by either using a tracer or a function using a tracer (for instance - torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and - restoring anything necessary after the retrace. - """ - if tracer is None and func is None: - raise ValueError("Either a tracer or a function using a tracer must be provided.") - elif tracer is not None and func is not None: - raise ValueError("Either provide a tracer or a function using a tracer, but not both.") - else: - gm, attributes = prepare_for_retracing(gm) - tracing_func = tracer.trace if tracer else func - traced = tracing_func(gm) - restore_after_retracing_(traced, attributes) - return traced - - -def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): - if forbidden_values is None: - forbidden_values = [] - value = random.randint(low, high) - while value in forbidden_values: - value = random.randint(low, high) - return value - - def symbolic_trace( model: PreTrainedModel, input_names: Optional[List[str]] = None, - batch_size: int = 1, - sequence_length: Union[int, List[int], Tuple[int]] = (128, 128), - num_choices: int = -1, ) -> GraphModule: """ @@ -510,89 +562,33 @@ def symbolic_trace( The model to trace. input_names (`List[str]`, *optional*): The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead. - batch_size (`int`, *optional*, defaults to 1): - The batch size of the traced model inputs. - sequence_length (`int` or `List[int]]`): - The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence - lengths between the encoder and the decoder inputs, this must be `[encoder_sequence_length, - decoder_sequence_length]`. - num_choices (`int`, *optional*, defaults to -1): - The number of possible choices for a multiple choice task. Returns: `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. Example: - ```python - from transformers.utils.fx import symbolic_trace + ```python + from transformers.utils.fx import symbolic_trace - traced_model = symbolic_trace( - model, - input_names=["input_ids", "attention_mask", "token_type_ids"], - batch_size=1, - sequence_length=128, - ) - ```""" + traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"]) + ``` + """ if input_names is None: input_names = model.dummy_inputs.keys() sig = inspect.signature(model.forward) concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} - # Preparing HFTracer batch_size and sequence_lenght values for potential dynamic axes. - use_dynamic_batch_size = batch_size <= 0 - if isinstance(sequence_length, (list, tuple)): - use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0 - else: - use_dynamic_sequence_length = sequence_length <= 0 - - if use_dynamic_batch_size or use_dynamic_sequence_length: - forbidden_values = [ - model.config.num_attention_heads, - model.config.hidden_size, - model.config.hidden_size // model.config.num_attention_heads, - ] - if use_dynamic_batch_size: - batch_size = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(batch_size) - if use_dynamic_sequence_length: - encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - forbidden_values.append(encoder_sequence_length) - decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values) - sequence_length = [encoder_sequence_length, decoder_sequence_length] - if not isinstance(model, _SUPPORTED_MODELS): supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS)) raise NotImplementedError( f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}" ) - if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance( - model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES - ): - supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES)) - raise NotImplementedError( - f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}" - ) # Tracing. - tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices) - + tracer = HFTracer() traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) - traced.config = copy.deepcopy(model.config) - traced.num_choices = num_choices - traced.dummy_inputs = {} - - for name in input_names: - traced.dummy_inputs.update(tracer._generate_dummy_input(model, name)) - - traced.use_dynamic_batch_size = use_dynamic_batch_size - traced.use_dynamic_sequence_length = use_dynamic_sequence_length - traced.static_batch_size = batch_size - traced.static_sequence_length = sequence_length - - transform_to_dynamic_input_(traced) - return traced diff --git a/src/transformers/utils/fx_transformations.py b/src/transformers/utils/fx_transformations.py deleted file mode 100644 index 3e181617af..0000000000 --- a/src/transformers/utils/fx_transformations.py +++ /dev/null @@ -1,321 +0,0 @@ -import copy -import functools -import operator -from inspect import signature -from typing import Any, Callable, Dict, Optional, Union - -import torch -from torch.fx import Graph, GraphModule, Node - - -# Torch FX transformation convention: -# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation -# - transformations that are inplace have a name ending with "_" - - -def _cache_attributes(gm: GraphModule) -> Dict[str, Any]: - attributes_to_keep = [ - "config", - "num_choices", - "dummy_inputs", - "use_dynamic_batch_size", - "use_dynamic_sequence_length", - "static_batch_size", - "static_sequence_length", - "static2dynamic", - "dynamic2static", - ] - attributes = {k: getattr(gm, k, None) for k in attributes_to_keep} - return attributes - - -def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]): - for name, attr in attributes.items(): - setattr(gm, name, attr) - - -def deepcopy_graph(gm: GraphModule) -> GraphModule: - """ - Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was - traced with dynamic axes, and what were the values if that is the case. - """ - - # First, create a copy of the module without the graph. - graph = gm.__dict__.pop("_graph") - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(gm.__dict__) - gm.__dict__["_graph"] = graph - - # Then, copy the graph. - val_map = {} - graph_clone = Graph() - output_val = graph_clone.graph_copy(graph, val_map=val_map) - graph_clone.output(output_val) - - # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies. - # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule. - clone = gm.__class__(fake_mod, graph_clone) - - # Restore the dynamic axes related attributes to the clone. - attributes = _cache_attributes(gm) - attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()} - attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()} - _restore_attributes_(clone, attributes) - - return clone - - -def transformation(func): - """ - Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the - original. - """ - - def map_fn(arg): - if isinstance(arg, GraphModule): - return deepcopy_graph(arg) - return arg - - @functools.wraps(func) - def wrapper(*args, **kwargs): - new_args = tuple(map_fn(arg) for arg in args) - new_kwargs = {k: map_fn(v) for k, v in kwargs.items()} - return func(*new_args, **new_kwargs) - - wrapper._is_transformation = True - - return wrapper - - -def compose_transformations( - *args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False -) -> GraphModule: - """ - Allows to compose transformations together and takes of: - - 1. Performing the transformations on a copy of the GraphModule if inplace is set to False, transformations that - are decorated with @transformation (which means that they are not modifying the original GraphModule) are - unwrapped to make them inplace. - 2. Linting and recompiling only at the end of the composition for performance purposes. - """ - args = list(args) - if not inplace: - args.insert(0, deepcopy_graph) - - for i, transformation in enumerate(args[:-1]): - sig = signature(transformation) - - # Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is - # already handled by this function. - if getattr(transformation, "_is_transformation", False): - transformation = transformation.__wrapped__ - - # Linting and recompiling only after the last transformation applied to make composition efficient. - if "lint_and_recompile" in sig.parameters: - args[i] = functools.partial(transformation, lint_and_recompile=False) - - def reduce_func(f, g): - def compose_f_and_g(gm): - output_g = g(gm) - if output_g is None: - output_g = gm - output_f = f(output_g) - if output_f is None: - output_f = gm - return output_f - - return compose_f_and_g - - return functools.reduce(reduce_func, reversed(args), lambda x: x) - - -def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True): - """Removes all the unused nodes in a GraphModule.""" - graph = gm.graph - for node in graph.nodes: - if not node.users and node.op not in ["placeholder", "output"]: - graph.erase_node(node) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the batch size dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - batch_size_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names: - with graph.inserting_after(node): - batch_size_node = graph.call_method("size", args=(node, 0)) - - if batch_size_node is None: - raise ValueError("Could not insert the node that computes the batch size") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[batch_size_node.name] = None - - return batch_size_node - - -def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node: - """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model.""" - graph = gm.graph - input_names = set(gm.dummy_inputs.keys()) - encoder_sequence_length_node = None - for node in graph.nodes: - if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name: - with graph.inserting_after(node): - # There are two cases to handle: - # 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the - # input shapes is [batch_size, sequence_length] => index 1 - # 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input - # shape is [batch_size, num_choices, sequence_length] => index 2 - encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2)) - - if encoder_sequence_length_node is None: - raise ValueError("Could not insert the node that computes the encoder sequence length") - - if lint_and_recompile: - graph.lint() - gm.recompile() - - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[encoder_sequence_length_node.name] = None - - return encoder_sequence_length_node - - -def _change_view_methods_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the - batch_size / sequence_length nodes. - """ - graph = gm.graph - for node in graph.nodes: - if node.op == "call_method" and node.target == "view": - if isinstance(node.args[1], tuple): - node.args = (node.args[0], *node.args[1]) - node.args = tuple((mapping.get(arg, arg) for arg in node.args)) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_getitem_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """Patches getitem nodes by replacing current arguments to their corresponding values in mapping.""" - # TODO: combine this with the patch_argument function which seems to do almost the same thing. - graph = gm.graph - for node in graph.nodes: - if node.op == "call_function" and node.target == operator.getitem: - indices = node.args[1] - if isinstance(indices, tuple): - new_indices = [] - for idx in indices: - if isinstance(idx, slice): - new_indices.append( - slice( - mapping.get(idx.start, idx.start), - mapping.get(idx.stop, idx.stop), - mapping.get(idx.step, idx.step), - ) - ) - elif isinstance(idx, int): - new_indices.append(mapping.get(idx, idx)) - else: - new_indices.append(idx) - - node.args = (node.args[0], tuple(new_indices)) - else: - node.args = (node.args[0], mapping.get(node.args[1], node.args[1])) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def _patch_arguments_( - gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True -): - """ - Patches node by replacing their argument to their corresponding values in mapping (supports regular types, tuples - and slices). - """ - - def _patch_slice(s, mapping): - return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step)) - - graph = gm.graph - supported_types = (Node, str, int, float) - for node in graph.nodes: - new_args = [] - for arg in node.args: - if isinstance(arg, tuple): - new_arg = [] - for a in arg: - if isinstance(a, slice): - new_arg.append(_patch_slice(a, mapping)) - else: - new_arg.append(mapping.get(a, a)) - new_args.append(tuple(new_arg)) - elif isinstance(arg, slice): - new_args.append(_patch_slice(arg, mapping)) - elif isinstance(arg, supported_types): - new_args.append(mapping.get(arg, arg)) - else: - new_args.append(arg) - node.args = tuple(new_args) - - if lint_and_recompile: - graph.lint() - gm.recompile() - - -def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False): - """Transformation that enables traced models to perform inference on dynamic input shapes.""" - graph = gm.graph - static2dynamic = {} - - # Inserting the nodes that will fetch the batch size and sequence lengths dynamically. - if gm.use_dynamic_batch_size: - batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_batch_size] = batch_size_node - if gm.num_choices > 0: - with graph.inserting_after(batch_size_node): - static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function( - operator.mul, args=(batch_size_node, gm.num_choices) - ) - # Useful when retracing for quantization. - if hasattr(gm, "_qconfig_map"): - gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None - - if gm.use_dynamic_sequence_length: - encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False) - static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node - - # TODO: do the same for the decoder. - pass - - _change_view_methods_(gm, static2dynamic, lint_and_recompile=False) - _patch_getitem_(gm, static2dynamic, lint_and_recompile=False) - - remove_unused_nodes_(gm, lint_and_recompile=False) - - graph.lint() - gm.recompile() - - gm.static2dynamic = static2dynamic - gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()} diff --git a/tests/test_modeling_albert.py b/tests/test_modeling_albert.py index d16dcadd5e..ab5595f4b6 100644 --- a/tests/test_modeling_albert.py +++ b/tests/test_modeling_albert.py @@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 7a66285097..7b8738fd60 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 420b775031..22cbd66b81 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -116,8 +116,7 @@ class ModelTesterMixin: model_tester = None all_model_classes = () all_generative_model_classes = () - fx_ready_model_classes = () - fx_dynamic_ready_model_classes = () + fx_compatible = False test_torchscript = True test_pruning = True test_resize_embeddings = True @@ -666,19 +665,14 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True) - def test_torch_fx_dynamic_axes(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True) - - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False): - if not is_torch_fx_available(): + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: return configs_no_init = _config_zero_init(config) # To be sure we have no Nan configs_no_init.return_dict = False - model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes - for model_class in model_classes: + for model_class in self.all_model_classes: model = model_class(config=configs_no_init) model.to(torch_device) model.eval() @@ -687,8 +681,6 @@ class ModelTesterMixin: try: if model.config.is_encoder_decoder: model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - input_ids = inputs["input_ids"] - decoder_attention_mask = inputs["decoder_attention_mask"] labels = inputs.get("labels", None) input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] if labels is not None: @@ -697,17 +689,7 @@ class ModelTesterMixin: model_output = model(**filtered_inputs) - batch_size = input_ids.shape[0] - encoder_sequence_length = input_ids.shape[1] - decoder_sequence_length = decoder_attention_mask.shape[1] - - traced_model = symbolic_trace( - model, - input_names, - batch_size=batch_size if not dynamic_axes else -1, - sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1, - ) - + traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: input_names = ["input_ids", "attention_mask", "token_type_ids"] @@ -729,23 +711,12 @@ class ModelTesterMixin: model_output = model(**filtered_inputs) rank = len(input_ids.shape) - if rank == 2: - batch_size, sequence_length = input_ids.shape - num_choices = -1 - elif rank == 3: - batch_size, num_choices, sequence_length = input_ids.shape - else: + if rank not in [2, 3]: raise NotImplementedError( f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}." ) - traced_model = symbolic_trace( - model, - input_names, - batch_size=batch_size if not dynamic_axes else -1, - sequence_length=sequence_length if not dynamic_axes else -1, - num_choices=num_choices, - ) + traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) except RuntimeError: diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index ee8a8cbd3d..b81e42bcf1 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -209,8 +209,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else None ) - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes + fx_compatible = True test_pruning = True test_torchscript = True test_resize_embeddings = True diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index be19f8d610..065d596826 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -369,10 +369,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else () - - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index ef51c815e4..cd13be27bb 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -433,7 +433,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ) all_generative_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_compatible = True test_missing_keys = False test_model_parallel = True diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index a8e5b4babc..b8f942ef17 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -372,7 +372,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else () ) all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_compatible = True test_missing_keys = False test_pruning = False test_model_parallel = False diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index dd743b80d7..d6b9f92926 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -363,7 +363,7 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_compatible = True test_pruning = False test_missing_keys = False test_model_parallel = False diff --git a/tests/test_modeling_megatron_bert.py b/tests/test_modeling_megatron_bert.py index a7f47ddea3..7ac507988f 100644 --- a/tests/test_modeling_megatron_bert.py +++ b/tests/test_modeling_megatron_bert.py @@ -283,9 +283,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes - + fx_compatible = True # test_resize_embeddings = False test_head_masking = False diff --git a/tests/test_modeling_mobilebert.py b/tests/test_modeling_mobilebert.py index 716714157a..6ca14526a6 100644 --- a/tests/test_modeling_mobilebert.py +++ b/tests/test_modeling_mobilebert.py @@ -269,8 +269,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) - fx_ready_model_classes = all_model_classes - fx_dynamic_ready_model_classes = all_model_classes + fx_compatible = True # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index d0a8aab6b7..1a55fda152 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -356,6 +356,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas else () ) all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () + fx_compatible = True def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 575850aa90..c0b5739bca 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -509,7 +509,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () - fx_ready_model_classes = all_model_classes + fx_compatible = True all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () test_pruning = False test_torchscript = True diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 5516b28e17..f4e90fbe77 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) all_generative_model_classes = ( (XLNetLMHeadModel,) if is_torch_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_pruning = False # XLNet has 2 QA models -> need to manually set the correct labels for one of them here