Document check copies (#25291)

* Document check copies better and add tests

* Include header in check for copies

* Manual fixes

* Try autofix

* Fixes

* Clean tests

* Finalize doc

* Remove debug print

* More fixes
This commit is contained in:
Sylvain Gugger
2023-08-04 14:56:29 +02:00
committed by GitHub
parent 29f04002e6
commit f0fd73a2de
51 changed files with 382 additions and 166 deletions

View File

@@ -13,19 +13,19 @@
# limitations under the License.
import os
import re
import shutil
import sys
import tempfile
import unittest
import black
from contextlib import contextmanager
from pathlib import Path
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
import check_copies # noqa: E402
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
# This is the reference code that will be used in the tests.
@@ -49,78 +49,137 @@ REFERENCE_CODE = """ def __init__(self, config):
return hidden_states
"""
MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
def bert_function(x):
return x
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
class BertModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bert = BertEncoder(config)
@add_docstring(BERT_DOCSTRING)
def forward(self, x):
return self.bert(x)
"""
MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
# Copied from transformers.models.bert.modeling_bert.bert_function
def bert_copy_function(x):
return x
# Copied from transformers.models.bert.modeling_bert.BertAttention
class BertCopyAttention(nn.Module):
def __init__(self, config):
super().__init__()
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
class BertCopyModel(BertCopyPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bertcopy = BertCopyEncoder(config)
@add_docstring(BERTCOPY_DOCSTRING)
def forward(self, x):
return self.bertcopy(x)
"""
def replace_in_file(filename, old, new):
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
content = content.replace(old, new)
with open(filename, "w", encoding="utf-8") as f:
f.write(content)
def create_tmp_repo(tmp_dir):
"""
Creates a mock repository in a temporary folder for testing.
"""
tmp_dir = Path(tmp_dir)
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True)
model_dir = tmp_dir / "src" / "transformers" / "models"
model_dir.mkdir(parents=True, exist_ok=True)
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
for model, code in models.items():
model_subdir = model_dir / model
model_subdir.mkdir(exist_ok=True)
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
f.write(code)
@contextmanager
def patch_transformer_repo_path(new_folder):
"""
Temporarily patches the variables defines in `check_copies` to use a different location for the repo.
"""
old_repo_path = check_copies.REPO_PATH
old_doc_path = check_copies.PATH_TO_DOCS
old_transformer_path = check_copies.TRANSFORMERS_PATH
repo_path = Path(new_folder).resolve()
check_copies.REPO_PATH = str(repo_path)
check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
try:
yield
finally:
check_copies.REPO_PATH = old_repo_path
check_copies.PATH_TO_DOCS = old_doc_path
check_copies.TRANSFORMERS_PATH = old_transformer_path
class CopyCheckTester(unittest.TestCase):
def setUp(self):
self.transformer_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(self.transformer_dir, "models/bert/"))
check_copies.TRANSFORMER_PATH = self.transformer_dir
shutil.copy(
os.path.join(git_repo_path, "src/transformers/models/bert/modeling_bert.py"),
os.path.join(self.transformer_dir, "models/bert/modeling_bert.py"),
)
def tearDown(self):
check_copies.TRANSFORMER_PATH = "src/transformers"
shutil.rmtree(self.transformer_dir)
def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
if overwrite_result is not None:
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
code = black.format_str(code, mode=mode)
fname = os.path.join(self.transformer_dir, "new_code.py")
with open(fname, "w", newline="\n") as f:
f.write(code)
if overwrite_result is None:
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
else:
check_copies.is_copy_consistent(f.name, overwrite=True)
with open(fname, "r") as f:
self.assertTrue(f.read(), expected)
def test_find_code_in_transformers(self):
code = check_copies.find_code_in_transformers("models.bert.modeling_bert.BertLMPredictionHead")
self.assertEqual(code, REFERENCE_CODE)
with tempfile.TemporaryDirectory() as tmp_folder:
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
reference_code = (
"class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
)
self.assertEqual(code, reference_code)
def test_is_copy_consistent(self):
# Base copy consistency
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
"BertLMPredictionHead",
REFERENCE_CODE + "\n",
)
path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
with tempfile.TemporaryDirectory() as tmp_folder:
# Base check
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [])
# With no empty line at the end
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
"BertLMPredictionHead",
REFERENCE_CODE,
)
# Base check with an inconsistency
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
# Copy consistency with rename
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
"TestModelLMPredictionHead",
re.sub("Bert", "TestModel", REFERENCE_CODE),
)
replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
# Copy consistency with a really long name
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
self.check_copy_consistency(
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
f"{long_class_name}LMPredictionHead",
re.sub("Bert", long_class_name, REFERENCE_CODE),
)
diffs = is_copy_consistent(file_to_check, overwrite=True)
# Copy consistency with overwrite
self.check_copy_consistency(
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
"TestModelLMPredictionHead",
REFERENCE_CODE,
overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
)
with open(file_to_check, "r", encoding="utf-8") as f:
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
@@ -168,14 +227,14 @@ class CopyCheckTester(unittest.TestCase):
" Christopher D. Manning 发布。\n"
)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
)
self.assertFalse(num_models_equal)
self.assertEqual(converted_md_list, converted_md_list_sample)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, converted_md_list, localized_readme["format_model_list"]
)
@@ -201,7 +260,7 @@ class CopyCheckTester(unittest.TestCase):
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
num_models_equal, converted_md_list = convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
)