Improve forward signature test (#27729)
* First draft * Extend test_forward_signature * Update tests/test_modeling_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Revert suggestion --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -542,6 +542,12 @@ class ModelTesterMixin:
|
|||||||
else ["encoder_outputs"]
|
else ["encoder_outputs"]
|
||||||
)
|
)
|
||||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
|
||||||
|
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
|
||||||
|
self.assertListEqual(arg_names, expected_arg_names)
|
||||||
|
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
|
||||||
|
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
|
||||||
|
self.assertListEqual(arg_names, expected_arg_names)
|
||||||
else:
|
else:
|
||||||
expected_arg_names = [model.main_input_name]
|
expected_arg_names = [model.main_input_name]
|
||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|||||||
Reference in New Issue
Block a user