Check decorator order (#7326)
* Check decorator order * Adapt for parametrized decorators * Fix typos
This commit is contained in:
@@ -185,8 +185,8 @@ class BertGenerationTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1419,8 +1419,8 @@ class TokenizerTesterMixin:
|
||||
# add pad_token_id to pass subsequent tests
|
||||
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
import torch
|
||||
|
||||
@@ -1470,8 +1470,8 @@ class TokenizerTesterMixin:
|
||||
# model(**encoded_sequence_fast)
|
||||
# model(**batch_encoded_sequence_fast)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_encode_plus_sent_to_model(self):
|
||||
from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING
|
||||
|
||||
@@ -1505,8 +1505,8 @@ class TokenizerTesterMixin:
|
||||
model(batch_encoded_sequence)
|
||||
|
||||
# TODO: Check if require_torch is the best to test for numpy here ... Maybe move to require_flax when available
|
||||
@slow
|
||||
@require_torch
|
||||
@slow
|
||||
def test_np_encode_plus_sent_to_model(self):
|
||||
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
|
||||
|
||||
|
||||
@@ -230,8 +230,8 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
import torch
|
||||
|
||||
|
||||
@@ -273,9 +273,46 @@ def check_all_models_are_documented():
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
|
||||
|
||||
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
|
||||
|
||||
|
||||
def check_decorator_order(filename):
|
||||
""" Check that in the test file `filename` the slow decorator is always last."""
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
decorator_before = None
|
||||
errors = []
|
||||
for i, line in enumerate(lines):
|
||||
search = _re_decorator.search(line)
|
||||
if search is not None:
|
||||
decorator_name = search.groups()[0]
|
||||
if decorator_before is not None and decorator_name.startswith("parameterized"):
|
||||
errors.append(i)
|
||||
decorator_before = decorator_name
|
||||
elif decorator_before is not None:
|
||||
decorator_before = None
|
||||
return errors
|
||||
|
||||
|
||||
def check_all_decorator_order():
|
||||
""" Check that in all test files, the slow decorator is always last."""
|
||||
errors = []
|
||||
for fname in os.listdir(PATH_TO_TESTS):
|
||||
if fname.endswith(".py"):
|
||||
filename = os.path.join(PATH_TO_TESTS, fname)
|
||||
new_errors = check_decorator_order(filename)
|
||||
errors += [f"- {filename}, line {i}" for i in new_errors]
|
||||
if len(errors) > 0:
|
||||
msg = "\n".join(errors)
|
||||
raise ValueError(
|
||||
f"The parameterized decorator (and its variants) should always be first, but this is not the case in the following files:\n{msg}"
|
||||
)
|
||||
|
||||
|
||||
def check_repo_quality():
|
||||
""" Check all models are properly tested and documented."""
|
||||
print("Checking all models are properly tested.")
|
||||
check_all_decorator_order()
|
||||
check_all_models_are_tested()
|
||||
print("Checking all models are properly documented.")
|
||||
check_all_models_are_documented()
|
||||
|
||||
Reference in New Issue
Block a user