Fx support for multiple model architectures (#17393)
* Support for Bart and LayoutLM, and partial support for XLNet * Support for mbart * A lot of new models supported * Support for other models * LayoutLM fix * Use strings instead of classes
This commit is contained in:
@@ -738,17 +738,32 @@ class ModelTesterMixin:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"input_features",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"pixel_values",
|
||||
"bbox",
|
||||
"input_features",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
@@ -761,7 +776,7 @@ class ModelTesterMixin:
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user