Document check copies (#25291)
* Document check copies better and add tests
* Include header in check for copies
* Manual fixes
* Try autofix
* Fixes
* Clean tests
* Finalize doc
* Remove debug print
* More fixes
This commit is contained in:
@@ -13,19 +13,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import black
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
import check_copies # noqa: E402
|
||||
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
|
||||
|
||||
|
||||
# This is the reference code that will be used in the tests.
|
||||
@@ -49,78 +49,137 @@ REFERENCE_CODE = """ def __init__(self, config):
|
||||
return hidden_states
|
||||
"""
|
||||
|
||||
MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
|
||||
|
||||
def bert_function(x):
|
||||
return x
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.bert = BertEncoder(config)
|
||||
|
||||
@add_docstring(BERT_DOCSTRING)
|
||||
def forward(self, x):
|
||||
return self.bert(x)
|
||||
"""
|
||||
|
||||
MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.bert_function
|
||||
def bert_copy_function(x):
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention
|
||||
class BertCopyAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
|
||||
class BertCopyModel(BertCopyPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.bertcopy = BertCopyEncoder(config)
|
||||
|
||||
@add_docstring(BERTCOPY_DOCSTRING)
|
||||
def forward(self, x):
|
||||
return self.bertcopy(x)
|
||||
"""
|
||||
|
||||
|
||||
def replace_in_file(filename, old, new):
    """
    Replaces every occurrence of `old` by `new` inside the file `filename`, in place.

    Args:
        filename (`str` or `os.PathLike`): Path of the file to edit. It is read and written as UTF-8.
        old (`str`): The text to look for.
        new (`str`): The text to put in place of each occurrence of `old`.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    # The file is always rewritten, even when `old` was not found in it.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(content)
|
||||
|
||||
|
||||
def create_tmp_repo(tmp_dir):
    """
    Creates a mock repository in a temporary folder for testing.

    Any existing content at `tmp_dir` is deleted first. The mock repository then contains one modeling file per
    mock model, written to `src/transformers/models/<model>/modeling_<model>.py`:

    - `bert`, with the content of `MOCK_BERT_CODE`
    - `bertcopy`, with the content of `MOCK_BERT_COPY_CODE`

    Args:
        tmp_dir (`str` or `os.PathLike`): The folder in which to build the mock repository.
    """
    tmp_dir = Path(tmp_dir)
    # Start from a clean slate so a previous (possibly modified) mock repo never leaks into the current test.
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    # `parents=True` so a nested path whose parent does not exist yet is created instead of raising.
    tmp_dir.mkdir(parents=True, exist_ok=True)

    model_dir = tmp_dir / "src" / "transformers" / "models"
    model_dir.mkdir(parents=True, exist_ok=True)

    models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
    for model, code in models.items():
        model_subdir = model_dir / model
        model_subdir.mkdir(exist_ok=True)
        with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
            f.write(code)
|
||||
|
||||
|
||||
@contextmanager
def patch_transformer_repo_path(new_folder):
    """
    Temporarily patches the variables defined in `check_copies` to use a different location for the repo.

    While the context is active, `check_copies.REPO_PATH` is the resolved `new_folder`,
    `check_copies.PATH_TO_DOCS` is `<new_folder>/docs/source/en` and `check_copies.TRANSFORMERS_PATH` is
    `<new_folder>/src/transformers` (all stored as strings). The previous values are put back on exit.

    Args:
        new_folder (`str` or `os.PathLike`): The folder to use as the root of the repository.
    """
    # Remember the current values so they can be restored once the context exits.
    old_repo_path = check_copies.REPO_PATH
    old_doc_path = check_copies.PATH_TO_DOCS
    old_transformer_path = check_copies.TRANSFORMERS_PATH

    repo_path = Path(new_folder).resolve()
    check_copies.REPO_PATH = str(repo_path)
    check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
    check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
    try:
        yield
    finally:
        # Restore even if the body of the `with` block raised (e.g. a failing assertion in a test).
        check_copies.REPO_PATH = old_repo_path
        check_copies.PATH_TO_DOCS = old_doc_path
        check_copies.TRANSFORMERS_PATH = old_transformer_path
|
||||
|
||||
|
||||
class CopyCheckTester(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.transformer_dir = tempfile.mkdtemp()
|
||||
os.makedirs(os.path.join(self.transformer_dir, "models/bert/"))
|
||||
check_copies.TRANSFORMER_PATH = self.transformer_dir
|
||||
shutil.copy(
|
||||
os.path.join(git_repo_path, "src/transformers/models/bert/modeling_bert.py"),
|
||||
os.path.join(self.transformer_dir, "models/bert/modeling_bert.py"),
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
check_copies.TRANSFORMER_PATH = "src/transformers"
|
||||
shutil.rmtree(self.transformer_dir)
|
||||
|
||||
def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
|
||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||
if overwrite_result is not None:
|
||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||
code = black.format_str(code, mode=mode)
|
||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||
with open(fname, "w", newline="\n") as f:
|
||||
f.write(code)
|
||||
if overwrite_result is None:
|
||||
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
|
||||
else:
|
||||
check_copies.is_copy_consistent(f.name, overwrite=True)
|
||||
with open(fname, "r") as f:
|
||||
self.assertTrue(f.read(), expected)
|
||||
|
||||
def test_find_code_in_transformers(self):
|
||||
code = check_copies.find_code_in_transformers("models.bert.modeling_bert.BertLMPredictionHead")
|
||||
self.assertEqual(code, REFERENCE_CODE)
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
|
||||
|
||||
reference_code = (
|
||||
"class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
|
||||
)
|
||||
self.assertEqual(code, reference_code)
|
||||
|
||||
def test_is_copy_consistent(self):
|
||||
# Base copy consistency
|
||||
self.check_copy_consistency(
|
||||
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
|
||||
"BertLMPredictionHead",
|
||||
REFERENCE_CODE + "\n",
|
||||
)
|
||||
path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
# Base check
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [])
|
||||
|
||||
# With no empty line at the end
|
||||
self.check_copy_consistency(
|
||||
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
|
||||
"BertLMPredictionHead",
|
||||
REFERENCE_CODE,
|
||||
)
|
||||
# Base check with an inconsistency
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
|
||||
# Copy consistency with rename
|
||||
self.check_copy_consistency(
|
||||
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
|
||||
"TestModelLMPredictionHead",
|
||||
re.sub("Bert", "TestModel", REFERENCE_CODE),
|
||||
)
|
||||
replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
||||
|
||||
# Copy consistency with a really long name
|
||||
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
|
||||
self.check_copy_consistency(
|
||||
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
|
||||
f"{long_class_name}LMPredictionHead",
|
||||
re.sub("Bert", long_class_name, REFERENCE_CODE),
|
||||
)
|
||||
diffs = is_copy_consistent(file_to_check, overwrite=True)
|
||||
|
||||
# Copy consistency with overwrite
|
||||
self.check_copy_consistency(
|
||||
"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
|
||||
"TestModelLMPredictionHead",
|
||||
REFERENCE_CODE,
|
||||
overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
|
||||
)
|
||||
with open(file_to_check, "r", encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
||||
|
||||
def test_convert_to_localized_md(self):
|
||||
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
||||
@@ -168,14 +227,14 @@ class CopyCheckTester(unittest.TestCase):
|
||||
" Christopher D. Manning 发布。\n"
|
||||
)
|
||||
|
||||
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
md_list, localized_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
self.assertFalse(num_models_equal)
|
||||
self.assertEqual(converted_md_list, converted_md_list_sample)
|
||||
|
||||
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
md_list, converted_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
@@ -201,7 +260,7 @@ class CopyCheckTester(unittest.TestCase):
|
||||
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
|
||||
)
|
||||
|
||||
num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user