Fx support for Deberta-v[1-2], Hubert and LXMERT (#17539)

* Support for deberta and deberta-v2

* Support for LXMert

* Support for Hubert

* Fix for pt1.11

* Trigger CI
This commit is contained in:
Michael Benayoun
2022-06-07 18:05:20 +02:00
committed by GitHub
parent 3cab90279f
commit 5c8f601007
10 changed files with 221 additions and 29 deletions

View File

@@ -740,11 +740,12 @@ class ModelTesterMixin:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"decoder_input_ids",
"input_features",
"input_ids",
"input_values",
]
if labels is not None:
input_names.append("labels")
@@ -758,12 +759,15 @@ class ModelTesterMixin:
traced_output = traced_model(**filtered_inputs)
else:
input_names = [
"input_ids",
"attention_mask",
"token_type_ids",
"pixel_values",
"bbox",
"input_features",
"input_ids",
"input_values",
"pixel_values",
"token_type_ids",
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
@@ -781,10 +785,17 @@ class ModelTesterMixin:
model_output = model(**filtered_inputs)
if (
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
and not hasattr(model.config, "problem_type")
or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError as e:
except Exception as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):