From 4e2f4a92ddbd26cbd6cbbcf741016e04c43ff30d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 29 Jul 2022 16:12:27 +0200 Subject: [PATCH] [FX] Symbolic trace for Bloom (#18356) * Bloom model can now be traced * Bloom traced model can be torch scripted and serialized * Bloom can be traced with variable keyword arguments * Enable XLNet support * Disable XLNet for now --- .../models/bloom/modeling_bloom.py | 6 ++-- src/transformers/utils/fx.py | 32 +++++++++---------- tests/models/bloom/test_modeling_bloom.py | 2 +- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index ed5b50a77d..afa289afe5 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -244,7 +244,7 @@ class BloomAttention(nn.Module): new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) # fused_qkv = fused_qkv.transpose(1, 0) - fused_qkv = fused_qkv.reshape(*new_tensor_shape) + fused_qkv = fused_qkv.reshape(new_tensor_shape) # fused_qkv = fused_qkv.permute(0, 2, 1, 3) return torch.split(fused_qkv, self.head_dim, -1) @@ -306,7 +306,7 @@ class BloomAttention(nn.Module): attn_weights = (attention_scores * self.layer_number) + attention_mask attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) - attention_probs = attention_probs * (~attention_mask.bool()) + attention_probs = attention_probs * (~attention_mask.to(torch.bool)) # [batch_size, num_heads, q_length, k_length] attention_probs = self.attention_dropout(attention_probs) @@ -314,7 +314,7 @@ class BloomAttention(nn.Module): attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, k_length] - attention_probs_reshaped = attention_probs.view(*matmul_result.shape) + attention_probs_reshaped = attention_probs.view(matmul_result.shape) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm( diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 6fed5808f8..7eae67ba70 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -98,6 +98,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "bert", "blenderbot", "blenderbot-small", + "bloom", "clip", "deberta", "deberta-v2", @@ -127,8 +128,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "trocr", "vit", "xglm", - # "xlnet", - # TODO: add support for them as it should be quite easy to do so (small blocking issues). + # "xlnet", ] _REGULAR_SUPPORTED_MODELS = [] @@ -562,10 +562,8 @@ class HFProxy(Proxy): return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {}) def __contains__(self, key): - # To handle cases such as : - # `"some_key" in kwargs` - if self.node.op == "placeholder": - return False + if hasattr(self, "_metadata") and self._metadata is not None: + return key in self._metadata return super().__contains__(key) @@ -905,6 +903,9 @@ class HFTracer(Tracer): inputs.update(self._generate_dummy_input(root, input_name, shape)) concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()} + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names: + concrete_metas[f"**{param.name}"] = {} self.meta_args = concrete_metas self.patched_torch_methods = { target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH @@ -933,18 +934,15 @@ class HFTracer(Tracer): node.type = torch.Tensor # It is a concrete arg so it is not used and should be removed. else: - if hasattr(torch.fx._symbolic_trace, "_assert_is_none"): - # Newer versions of torch.fx emit an assert statement - # for concrete arguments; delete those before we delete - # the concrete arg. - to_delete = [] - for user in node.users: - if user.target == torch.fx._symbolic_trace._assert_is_none: - to_delete.append(user) - for user in to_delete: - self.graph.erase_node(user) + to_visit = [node] + to_delete = collections.OrderedDict() + while to_visit: + n = to_visit.pop(0) + to_delete[n] = None + to_visit += list(n.users.keys()) - self.graph.erase_node(node) + for user in reversed(to_delete.keys()): + self.graph.erase_node(user) # TODO: solves GraphModule creation. # Without this, return type annotation "Tuple" is causing code execution failure. diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index b0307b922c..4570cb7673 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -320,7 +320,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) ) all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else () - fx_compatible = False + fx_compatible = True test_missing_keys = False test_pruning = False test_torchscript = True # torch.autograd functions seems to be not supported