Allow `# Ignore copy` (#27328)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -95,13 +95,156 @@ class BertCopyModel(BertCopyPreTrainedModel):
|
||||
"""
|
||||
|
||||
|
||||
# Mock source for the "dummy_bert_match" model; `create_tmp_repo` writes it to
# `modeling_dummy_bert_match.py`. It is the reference class named by the
# `# Copied from` header of MOCK_DUMMY_ROBERTA_CODE_MATCH. The leading newline
# is part of the string, so the class statement is on line 2 of the file.
MOCK_DUMMY_BERT_CODE_MATCH = """
class BertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def existing_common(self, c):
        return 4

    def existing_diff_to_be_ignored(self, c):
        return 9
"""
|
||||
|
||||
|
||||
# Mock source for the "dummy_roberta_match" model; `create_tmp_repo` writes it
# to `modeling_dummy_roberta_match.py`. It declares itself a copy of
# `BertDummyModel` from MOCK_DUMMY_BERT_CODE_MATCH. The only parts that differ
# from that reference (an extra method, and a different return value in
# `existing_diff_to_be_ignored`) sit under `# Ignore copy` markers.
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:

    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def existing_common(self, c):
        return 4

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6
"""
|
||||
|
||||
|
||||
# Mock source for the "dummy_bert_no_match" model; `create_tmp_repo` writes it
# to `modeling_dummy_bert_no_match.py`. It is the reference class named by the
# `# Copied from` header of MOCK_DUMMY_ROBERTA_CODE_NO_MATCH, which differs
# from it in places that are not covered by `# Ignore copy`.
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
class BertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_bert(self, c):
        return 7

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 8

    def existing_diff_to_be_ignored(self, c):
        return 9
"""
|
||||
|
||||
|
||||
# Mock source for the "dummy_roberta_no_match" model; `create_tmp_repo` writes
# it to `modeling_dummy_roberta_no_match.py`. It declares itself a copy of
# `BertDummyModel` from MOCK_DUMMY_BERT_CODE_NO_MATCH but deviates outside of
# `# Ignore copy` sections. Counting the leading newline as line 1,
# `attr_2 = 3` is on line 6, the line number the no-match test expects in the
# reported diff.
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:

    attr_1 = 1
    attr_2 = 3

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_roberta_not_ignored(self, c):
        return 2

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 5

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6
"""
|
||||
|
||||
|
||||
# Content the no-match test expects to read back from
# `modeling_dummy_roberta_no_match.py` after
# `is_copy_consistent(..., overwrite=True)`. The non-ignored parts carry the
# values from MOCK_DUMMY_BERT_CODE_NO_MATCH, while the two `# Ignore copy`
# methods keep the values from MOCK_DUMMY_ROBERTA_CODE_NO_MATCH.
EXPECTED_REPLACED_CODE = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_bert(self, c):
        return 7

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 8

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3
"""
|
||||
|
||||
|
||||
def replace_in_file(filename, old, new):
    """Replace every occurrence of `old` with `new` in the file `filename`.

    Args:
        filename: Path of the UTF-8 text file to rewrite in place.
        old: Substring to search for.
        new: Replacement text.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    # `newline="\n"` turns off newline translation on write, so the file is
    # written with "\n" line endings regardless of the platform default.
    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)
|
||||
|
||||
|
||||
@@ -117,11 +260,18 @@ def create_tmp_repo(tmp_dir):
|
||||
model_dir = tmp_dir / "src" / "transformers" / "models"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
|
||||
models = {
|
||||
"bert": MOCK_BERT_CODE,
|
||||
"bertcopy": MOCK_BERT_COPY_CODE,
|
||||
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
|
||||
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
|
||||
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
|
||||
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
|
||||
}
|
||||
for model, code in models.items():
|
||||
model_subdir = model_dir / model
|
||||
model_subdir.mkdir(exist_ok=True)
|
||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
|
||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
|
||||
@@ -176,11 +326,47 @@ class CopyCheckTester(unittest.TestCase):
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
||||
|
||||
diffs = is_copy_consistent(file_to_check, overwrite=True)
|
||||
_ = is_copy_consistent(file_to_check, overwrite=True)
|
||||
|
||||
with open(file_to_check, "r", encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
||||
|
||||
def test_is_copy_consistent_with_ignored_match(self):
    """`is_copy_consistent` reports no diffs for the `dummy_roberta_match` mock file.

    In that mock, the only parts that differ from the referenced class are
    placed under `# Ignore copy` markers.
    """
    path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"]
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Base check
        create_tmp_repo(tmp_folder)
        with patch_transformer_repo_path(tmp_folder):
            file_to_check = os.path.join(tmp_folder, *path_to_check)
            diffs = is_copy_consistent(file_to_check)
            self.assertEqual(diffs, [])
|
||||
|
||||
def test_is_copy_consistent_with_ignored_no_match(self):
    """`is_copy_consistent` reports, then overwrites, the `dummy_roberta_no_match` mock file.

    The first call must report a single diff at line 6. The second call, with
    `overwrite=True`, must leave the file equal to `EXPECTED_REPLACED_CODE`.
    """
    path_to_check = [
        "src",
        "transformers",
        "models",
        "dummy_roberta_no_match",
        "modeling_dummy_roberta_no_match.py",
    ]
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Base check with an inconsistency
        create_tmp_repo(tmp_folder)
        with patch_transformer_repo_path(tmp_folder):
            file_to_check = os.path.join(tmp_folder, *path_to_check)

            diffs = is_copy_consistent(file_to_check)
            # line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`.
            # (which has a leading `\n`.)
            self.assertEqual(
                diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
            )

            # The return value is not needed here; only the rewritten file is checked.
            _ = is_copy_consistent(file_to_check, overwrite=True)

            with open(file_to_check, "r", encoding="utf-8") as f:
                self.assertEqual(f.read(), EXPECTED_REPLACED_CODE)
|
||||
|
||||
def test_convert_to_localized_md(self):
|
||||
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user