transformers.fx.symbolic_trace supports inputs_embeds (#31574)
* symbolic trace supports inputs_embeds * fix test? * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -995,6 +995,13 @@ class HFTracer(Tracer):
|
|||||||
inputs_dict[input_name] = torch.zeros(
|
inputs_dict[input_name] = torch.zeros(
|
||||||
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||||||
)
|
)
|
||||||
|
elif "inputs_embeds" in input_name:
|
||||||
|
batch_size = shape[0]
|
||||||
|
sequence_length = shape[-1]
|
||||||
|
|
||||||
|
inputs_dict[input_name] = torch.zeros(
|
||||||
|
batch_size, sequence_length, model.config.hidden_size, dtype=torch.float, device=device
|
||||||
|
)
|
||||||
elif "visual_feats" in input_name:
|
elif "visual_feats" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(
|
inputs_dict[input_name] = torch.zeros(
|
||||||
shape
|
shape
|
||||||
|
|||||||
@@ -1158,6 +1158,7 @@ class ModelTesterMixin:
|
|||||||
"input_features",
|
"input_features",
|
||||||
"input_ids",
|
"input_ids",
|
||||||
"input_values",
|
"input_values",
|
||||||
|
"inputs_embeds",
|
||||||
"pixel_values",
|
"pixel_values",
|
||||||
"token_type_ids",
|
"token_type_ids",
|
||||||
"visual_feats",
|
"visual_feats",
|
||||||
@@ -1214,16 +1215,27 @@ class ModelTesterMixin:
|
|||||||
(past_mask, inputs_to_test[1]["attention_mask"]), dim=1
|
(past_mask, inputs_to_test[1]["attention_mask"]), dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "inputs_embeds" in inspect.signature(model.forward).parameters:
|
||||||
|
inputs_to_test.append(
|
||||||
|
{
|
||||||
|
"inputs_embeds": torch.rand(
|
||||||
|
2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for inps in inputs_to_test:
|
for inps in inputs_to_test:
|
||||||
filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
|
filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
|
||||||
input_names = list(filtered_inputs.keys())
|
input_names_to_trace = list(filtered_inputs.keys())
|
||||||
|
|
||||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||||
):
|
):
|
||||||
model.config.problem_type = "single_label_classification"
|
model.config.problem_type = "single_label_classification"
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
model.config.use_cache = "past_key_values" in input_names_to_trace
|
||||||
|
|
||||||
|
traced_model = symbolic_trace(model, input_names_to_trace)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user