Fx support for Deberta-v[1-2], Hubert and LXMERT (#17539)
* Support for deberta and deberta-v2 * Support for LXMert * Support for Hubert * Fix for pt1.11 * Trigger CI
This commit is contained in:
@@ -740,11 +740,12 @@ class ModelTesterMixin:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
@@ -758,12 +759,15 @@ class ModelTesterMixin:
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"pixel_values",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
@@ -781,10 +785,17 @@ class ModelTesterMixin:
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if (
|
||||
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
||||
and not hasattr(model.config, "problem_type")
|
||||
or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
|
||||
Reference in New Issue
Block a user