From ba743700f49608357d4618cdb658feafd3fa66e6 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 8 Jul 2024 13:17:28 +0200 Subject: [PATCH] transformers.fx.symbolic_trace supports inputs_embeds (#31574) * symbolic trace supports inputs_embeds * fix test? * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/fx.py | 7 +++++++ tests/test_modeling_common.py | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index c3687c035c..0aa296e705 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -995,6 +995,13 @@ class HFTracer(Tracer): inputs_dict[input_name] = torch.zeros( *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device ) + elif "inputs_embeds" in input_name: + batch_size = shape[0] + sequence_length = shape[-1] + + inputs_dict[input_name] = torch.zeros( + batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device + ) elif "visual_feats" in input_name: inputs_dict[input_name] = torch.zeros( shape diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cdc173cc64..7c3bc3dc9e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1158,6 +1158,7 @@ class ModelTesterMixin: "input_features", "input_ids", "input_values", + "inputs_embeds", "pixel_values", "token_type_ids", "visual_feats", @@ -1214,16 +1215,27 @@ class ModelTesterMixin: (past_mask, inputs_to_test[1]["attention_mask"]), dim=1 ) + if "inputs_embeds" in inspect.signature(model.forward).parameters: + inputs_to_test.append( + { + "inputs_embeds": torch.rand( + 2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device + ) + } + ) + for inps in inputs_to_test: filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names} - input_names = list(filtered_inputs.keys()) + input_names_to_trace = list(filtered_inputs.keys()) if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) + model.config.use_cache = "past_key_values" in input_names_to_trace + + traced_model = symbolic_trace(model, input_names_to_trace) with torch.no_grad(): traced_output = traced_model(**filtered_inputs)