Add a main_input_name attribute to all models (#14803)

* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-12-20 11:19:08 -05:00
committed by GitHub
parent 0940e9b242
commit 33f36c869f
32 changed files with 61 additions and 5 deletions

View File

@@ -1315,6 +1315,13 @@ class ModelTesterMixin:
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_correct_missing_keys(self):
if not self.test_missing_keys:
return

View File

@@ -778,6 +778,13 @@ class FlaxModelTesterMixin:
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "__call__"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_headmasking(self):
if not self.test_head_masking:
return

View File

@@ -1183,6 +1183,13 @@ class TFModelTesterMixin:
else:
new_model_without_prefix(input_ids)
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "call"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []