Allow # Ignore copy (#27328)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,9 @@ from ...test_tokenization_common import TokenizerTesterMixin
|
|||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
|
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
|
||||||
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
# Ignore copy
|
||||||
tokenizer_class = LongformerTokenizer
|
tokenizer_class = LongformerTokenizer
|
||||||
test_slow_tokenizer = True
|
test_slow_tokenizer = True
|
||||||
rust_tokenizer_class = LongformerTokenizerFast
|
rust_tokenizer_class = LongformerTokenizerFast
|
||||||
@@ -71,23 +73,19 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer
|
|
||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer
|
|
||||||
def get_rust_tokenizer(self, **kwargs):
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts
|
|
||||||
def get_input_output_texts(self, tokenizer):
|
def get_input_output_texts(self, tokenizer):
|
||||||
input_text = "lower newer"
|
input_text = "lower newer"
|
||||||
output_text = "lower newer"
|
output_text = "lower newer"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
text = "lower newer"
|
text = "lower newer"
|
||||||
@@ -99,7 +97,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer
|
|
||||||
def longformer_dict_integration_testing(self):
|
def longformer_dict_integration_testing(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
@@ -110,7 +107,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096
|
|
||||||
def test_sequence_builders(self):
|
def test_sequence_builders(self):
|
||||||
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
|
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
|
||||||
|
|
||||||
@@ -130,7 +126,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
assert encoded_sentence == encoded_text_from_decode
|
assert encoded_sentence == encoded_text_from_decode
|
||||||
assert encoded_pair == encoded_pair_from_decode
|
assert encoded_pair == encoded_pair_from_decode
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding
|
|
||||||
def test_space_encoding(self):
|
def test_space_encoding(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
@@ -171,11 +166,9 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
|
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
|
||||||
self.assertNotEqual(first_char, space_encoding)
|
self.assertNotEqual(first_char, space_encoding)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs
|
|
||||||
def test_pretokenized_inputs(self):
|
def test_pretokenized_inputs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens
|
|
||||||
def test_embeded_special_tokens(self):
|
def test_embeded_special_tokens(self):
|
||||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
@@ -208,7 +201,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
|
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args
|
|
||||||
def test_change_add_prefix_space_and_trim_offsets_args(self):
|
def test_change_add_prefix_space_and_trim_offsets_args(self):
|
||||||
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
|
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
|
||||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||||
@@ -223,7 +215,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
|
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
|
||||||
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
|
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
|
||||||
|
|
||||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments
|
|
||||||
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
|
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
|
||||||
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
|
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
|
||||||
# `trim_offsets`
|
# `trim_offsets`
|
||||||
|
|||||||
@@ -95,13 +95,156 @@ class BertCopyModel(BertCopyPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_DUMMY_BERT_CODE_MATCH = """
|
||||||
|
class BertDummyModel:
|
||||||
|
attr_1 = 1
|
||||||
|
attr_2 = 2
|
||||||
|
|
||||||
|
def __init__(self, a=1, b=2):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
|
||||||
|
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||||
|
def forward(self, c):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def existing_common(self, c):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def existing_diff_to_be_ignored(self, c):
|
||||||
|
return 9
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
|
||||||
|
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||||
|
class RobertaBertDummyModel:
|
||||||
|
|
||||||
|
attr_1 = 1
|
||||||
|
attr_2 = 2
|
||||||
|
|
||||||
|
def __init__(self, a=1, b=2):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def only_in_roberta_to_be_ignored(self, c):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||||
|
def forward(self, c):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def existing_common(self, c):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def existing_diff_to_be_ignored(self, c):
|
||||||
|
return 6
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
|
||||||
|
class BertDummyModel:
|
||||||
|
attr_1 = 1
|
||||||
|
attr_2 = 2
|
||||||
|
|
||||||
|
def __init__(self, a=1, b=2):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
|
||||||
|
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||||
|
def forward(self, c):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def only_in_bert(self, c):
|
||||||
|
return 7
|
||||||
|
|
||||||
|
def existing_common(self, c):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def existing_diff_not_ignored(self, c):
|
||||||
|
return 8
|
||||||
|
|
||||||
|
def existing_diff_to_be_ignored(self, c):
|
||||||
|
return 9
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
|
||||||
|
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||||
|
class RobertaBertDummyModel:
|
||||||
|
|
||||||
|
attr_1 = 1
|
||||||
|
attr_2 = 3
|
||||||
|
|
||||||
|
def __init__(self, a=1, b=2):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def only_in_roberta_to_be_ignored(self, c):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||||
|
def forward(self, c):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def only_in_roberta_not_ignored(self, c):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
def existing_common(self, c):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def existing_diff_not_ignored(self, c):
|
||||||
|
return 5
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def existing_diff_to_be_ignored(self, c):
|
||||||
|
return 6
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
EXPECTED_REPLACED_CODE = """
|
||||||
|
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||||
|
class RobertaBertDummyModel:
|
||||||
|
attr_1 = 1
|
||||||
|
attr_2 = 2
|
||||||
|
|
||||||
|
def __init__(self, a=1, b=2):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
|
||||||
|
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||||
|
def forward(self, c):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def only_in_bert(self, c):
|
||||||
|
return 7
|
||||||
|
|
||||||
|
def existing_common(self, c):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def existing_diff_not_ignored(self, c):
|
||||||
|
return 8
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def existing_diff_to_be_ignored(self, c):
|
||||||
|
return 6
|
||||||
|
|
||||||
|
# Ignore copy
|
||||||
|
def only_in_roberta_to_be_ignored(self, c):
|
||||||
|
return 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def replace_in_file(filename, old, new):
    """Replace every occurrence of `old` by `new` in the text file `filename`, in place.

    Args:
        filename: Path of the UTF-8 encoded text file to rewrite.
        old (`str`): The text to search for.
        new (`str`): The text that replaces each occurrence of `old`.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    # `newline="\n"` turns off newline translation on write, so the file keeps `\n` line endings on every platform.
    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)
|
||||||
|
|
||||||
|
|
||||||
@@ -117,11 +260,18 @@ def create_tmp_repo(tmp_dir):
|
|||||||
model_dir = tmp_dir / "src" / "transformers" / "models"
|
model_dir = tmp_dir / "src" / "transformers" / "models"
|
||||||
model_dir.mkdir(parents=True, exist_ok=True)
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
|
models = {
|
||||||
|
"bert": MOCK_BERT_CODE,
|
||||||
|
"bertcopy": MOCK_BERT_COPY_CODE,
|
||||||
|
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
|
||||||
|
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
|
||||||
|
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
|
||||||
|
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
|
||||||
|
}
|
||||||
for model, code in models.items():
|
for model, code in models.items():
|
||||||
model_subdir = model_dir / model
|
model_subdir = model_dir / model
|
||||||
model_subdir.mkdir(exist_ok=True)
|
model_subdir.mkdir(exist_ok=True)
|
||||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
|
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
|
|
||||||
|
|
||||||
@@ -176,11 +326,47 @@ class CopyCheckTester(unittest.TestCase):
|
|||||||
diffs = is_copy_consistent(file_to_check)
|
diffs = is_copy_consistent(file_to_check)
|
||||||
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
||||||
|
|
||||||
diffs = is_copy_consistent(file_to_check, overwrite=True)
|
_ = is_copy_consistent(file_to_check, overwrite=True)
|
||||||
|
|
||||||
with open(file_to_check, "r", encoding="utf-8") as f:
|
with open(file_to_check, "r", encoding="utf-8") as f:
|
||||||
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
||||||
|
|
||||||
|
def test_is_copy_consistent_with_ignored_match(self):
    """The `dummy_roberta_match` mock file must be reported as consistent (no diffs)."""
    relative_parts = ("src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py")
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Build the mock repository, then run the check against it.
        create_tmp_repo(tmp_folder)
        file_to_check = os.path.join(tmp_folder, *relative_parts)
        with patch_transformer_repo_path(tmp_folder):
            self.assertEqual(is_copy_consistent(file_to_check), [])
|
||||||
|
|
||||||
|
def test_is_copy_consistent_with_ignored_no_match(self):
    """The `dummy_roberta_no_match` mock file must report one diff and be rewritten to `EXPECTED_REPLACED_CODE`."""
    relative_parts = (
        "src",
        "transformers",
        "models",
        "dummy_roberta_no_match",
        "modeling_dummy_roberta_no_match.py",
    )
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Build the mock repository; the checked file is inconsistent with its source.
        create_tmp_repo(tmp_folder)
        file_to_check = os.path.join(tmp_folder, *relative_parts)
        with patch_transformer_repo_path(tmp_folder):
            # line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`.
            # (which has a leading `\n`.)
            expected_diffs = [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
            self.assertEqual(is_copy_consistent(file_to_check), expected_diffs)

            # Second pass rewrites the file in place; its return value is not checked.
            is_copy_consistent(file_to_check, overwrite=True)

            with open(file_to_check, "r", encoding="utf-8") as f:
                rewritten = f.read()
            self.assertEqual(rewritten, EXPECTED_REPLACED_CODE)
|
||||||
|
|
||||||
def test_convert_to_localized_md(self):
|
def test_convert_to_localized_md(self):
|
||||||
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ import glob
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List, Optional, Tuple
|
from collections import OrderedDict
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from transformers.utils import direct_transformers_import
|
from transformers.utils import direct_transformers_import
|
||||||
|
|
||||||
@@ -125,13 +126,213 @@ LOCALIZED_READMES = {
|
|||||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_definition_header_ending_line(line: str) -> bool:
|
||||||
|
# Helper function. Returns `True` if `line` is the end parenthesis of a class/function definition
|
||||||
|
return re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||||
|
|
||||||
|
|
||||||
def _should_continue(line: str, indent: str) -> bool:
    """Return `True` if `line` still belongs to the block whose body is indented by `indent`.

    A line belongs to the block when it starts with `indent`, when it is empty (whitespace only), or when it is the
    closing parenthesis line of a class/function definition header.

    Args:
        line (`str`): The line to test.
        indent (`str`): The whitespace prefix of the block body.

    Returns:
        `bool`: Whether scanning of the block should continue past `line`.
    """
    return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line)
|
||||||
|
|
||||||
|
|
||||||
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
def _sanity_check_splits(splits_1, splits_2, is_class):
|
||||||
|
"""Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match.
|
||||||
|
|
||||||
|
For the case of `class`, they must be of one of the following 3 cases:
|
||||||
|
|
||||||
|
- a single block without name:
|
||||||
|
|
||||||
|
class foo:
|
||||||
|
a = 1
|
||||||
|
|
||||||
|
- a consecutive sequence of (1 or more) blocks with name
|
||||||
|
|
||||||
|
class foo:
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
- a block without name, followed by a consecutive sequence of (1 or more) blocks with name
|
||||||
|
|
||||||
|
class foo:
|
||||||
|
a = 1
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def g(x):
|
||||||
|
return None
|
||||||
|
|
||||||
|
The 2 code snippets that give `splits_1` and `splits_2` have to be in the same case to pass this check, but the
|
||||||
|
number of blocks with name in the consecutive sequence is not taken into account.
|
||||||
|
|
||||||
|
For the case of `function or method`, we don't require it to be in one of the above 3 cases. However, the structure
|
||||||
|
of`splits_1` and `splits_2` have to match exactly. In particular, the number of blocks with name in a consecutive
|
||||||
|
sequence is taken into account.
|
||||||
|
"""
|
||||||
|
block_names_1 = []
|
||||||
|
block_names_2 = []
|
||||||
|
|
||||||
|
for block in splits_1[1:]:
|
||||||
|
if block[0].startswith("_block_without_name_"):
|
||||||
|
block_names_1.append("block_without_name")
|
||||||
|
elif not block[0].startswith("_empty_block_") and (
|
||||||
|
not is_class or len(block_names_1) == 0 or block_names_1[-1].startswith("block_without_name")
|
||||||
|
):
|
||||||
|
block_names_1.append("block_with_name")
|
||||||
|
|
||||||
|
for block in splits_2[1:]:
|
||||||
|
if block[0].startswith("_block_without_name_"):
|
||||||
|
block_names_2.append("block_without_name")
|
||||||
|
elif not block[0].startswith("_empty_block_") and (
|
||||||
|
not is_class or len(block_names_2) == 0 or block_names_2[-1].startswith("block_without_name")
|
||||||
|
):
|
||||||
|
block_names_2.append("block_with_name")
|
||||||
|
|
||||||
|
if is_class:
|
||||||
|
if block_names_1 not in [
|
||||||
|
["block_without_name"],
|
||||||
|
["block_with_name"],
|
||||||
|
["block_without_name", "block_with_name"],
|
||||||
|
]:
|
||||||
|
raise ValueError(
|
||||||
|
"For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if block_names_1 != block_names_2:
|
||||||
|
raise ValueError("The structures in the 2 code blocks differ.")
|
||||||
|
|
||||||
|
|
||||||
|
def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
    """
    Find the end of the class/func block starting at `start_index` in a source code (defined by `lines`).

    Args:
        lines (`List[str]`):
            The source code, represented by a list of lines.
        start_index (`int`):
            The starting index of the target class/func block.
        indent (`int`):
            The indent of the class/func body.

    Returns:
        `int`: The index of the block's ending line plus 1 (i.e. exclusive).
    """
    indent_str = " " * indent
    # enter the block body
    line_index = start_index + 1

    while line_index < len(lines) and _should_continue(lines[line_index], indent_str):
        line_index += 1
    # Clean up empty lines at the end (if any). The scan never moves back past the line right after `start_index`, so
    # it cannot run before the block itself.
    while line_index > start_index + 1 and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return line_index
|
||||||
|
|
||||||
|
|
||||||
|
def split_code_into_blocks(
    lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
) -> List[Tuple[str, int, int]]:
    """
    Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*.

    The block's header is included as the first element. The contiguous regions (without empty lines) that are not
    inside any inner block are included as blocks. The contiguous regions of empty lines that are not inside any inner
    block are also included as (dummy) blocks.

    Args:
        lines (`List[str]`):
            The source code, represented by a list of lines.
        start_index (`int`):
            The starting index of the target class/func block.
        end_index (`int`):
            The ending index of the target class/func block.
        indent (`int`):
            The indent of the class/func body.
        backtrace (`bool`, *optional*, defaults to `False`):
            Whether or not to include the lines before the inner class/func block's header (e.g. comments, decorators,
            etc.) until an empty line is encountered.

    Returns:
        `List[Tuple[str, int, int]]`: A list of elements with the form `(block_name, start_index, end_index)`.

    Raises:
        `ValueError`: If `lines[start_index]` is not a `class`/`def` header at indent level `indent - 4`.
    """
    splits = []
    header_line = lines[start_index]
    # `indent - 4` is the indent level of the target class/func header
    header_match = re.search(rf"^{' ' * (indent - 4)}((class|def)\s+\S+)(\(|\:)", header_line)
    if header_match is None:
        raise ValueError(f"Line {start_index} is not a class/function definition header: {header_line!r}")
    target_block_name = header_match.groups()[0]

    # from now on, the `block` means inner blocks unless explicitly specified
    indent_str = " " * indent
    block_without_name_idx = 0
    empty_block_idx = 0

    # Find the lines for the definition header
    index = start_index
    # NOTE(review): a one-line header with a return annotation (e.g. `def f(x) -> int:`) contains `(` but not `):`, so
    # it also takes this branch and is scanned up to the next header-ending line — confirm this is intended.
    if "(" in header_line and "):" not in header_line:
        while index < end_index:
            if _is_definition_header_ending_line(lines[index]):
                break
            index += 1

    # the first line outside the definition header
    index += 1
    splits.append((target_block_name, start_index, index))

    block_start_index, prev_block_end_index = index, index
    while index < end_index:
        # if found, it will be an inner block
        block_found = re.search(rf"^{indent_str}((class|def)\s+\S+)(\(|\:)", lines[index])
        if block_found:
            name = block_found.groups()[0]

            block_end_index = find_block_end(lines, index, indent + 4)

            # backtrace to include the lines before the found block's definition header (e.g. comments, decorators,
            # etc.) until an empty line is encountered.
            block_start_index = index
            if index > prev_block_end_index and backtrace:
                # Walk upwards over the contiguous non-empty lines at body indent, never entering the previous block.
                while (
                    block_start_index > prev_block_end_index
                    and len(lines[block_start_index - 1].strip()) > 0
                    and lines[block_start_index - 1].startswith(indent_str)
                ):
                    block_start_index -= 1

            # between the current found block and the previous found block
            if block_start_index > prev_block_end_index:
                # give it a dummy name
                if len("".join(lines[prev_block_end_index:block_start_index]).strip()) == 0:
                    prev_block_name = f"_empty_block_{empty_block_idx}"
                    empty_block_idx += 1
                else:
                    prev_block_name = f"_block_without_name_{block_without_name_idx}"
                    block_without_name_idx += 1
                # Add it as a block
                splits.append((prev_block_name, prev_block_end_index, block_start_index))

            # Add the current found block
            splits.append((name, block_start_index, block_end_index))
            prev_block_end_index = block_end_index
            # `index` is incremented below, so scanning resumes at `block_end_index`
            index = block_end_index - 1

        index += 1

    # Whatever remains after the last inner block is added as one final (dummy-named) block.
    if index > prev_block_end_index:
        if len("".join(lines[prev_block_end_index:index]).strip()) == 0:
            prev_block_name = f"_empty_block_{empty_block_idx}"
        else:
            prev_block_name = f"_block_without_name_{block_without_name_idx}"
        splits.append((prev_block_name, prev_block_end_index, index))

    return splits
|
||||||
|
|
||||||
|
|
||||||
|
def find_code_in_transformers(
|
||||||
|
object_name: str, base_path: str = None, return_indices: bool = False
|
||||||
|
) -> Union[str, Tuple[List[str], int, int]]:
|
||||||
"""
|
"""
|
||||||
Find and return the source code of an object.
|
Find and return the source code of an object.
|
||||||
|
|
||||||
@@ -140,9 +341,15 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
|||||||
The name of the object we want the source code of.
|
The name of the object we want the source code of.
|
||||||
base_path (`str`, *optional*):
|
base_path (`str`, *optional*):
|
||||||
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
|
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
|
||||||
|
return_indices(`bool`, *optional*, defaults to `False`):
|
||||||
|
If `False`, will only return the code (as a string), otherwise it will also return the whole lines of the
|
||||||
|
file where the object specified by `object_name` is defined, together the start/end indices of the block in
|
||||||
|
the file that defines the object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The source code of the object.
|
`Union[str, Tuple[List[str], int, int]]`: If `return_indices=False`, only the source code of the object will be
|
||||||
|
returned. Otherwise, it also returns the whole lines of the file where the object specified by `object_name` is
|
||||||
|
defined, together the start/end indices of the block in the file that defines the object.
|
||||||
"""
|
"""
|
||||||
parts = object_name.split(".")
|
parts = object_name.split(".")
|
||||||
i = 0
|
i = 0
|
||||||
@@ -181,22 +388,91 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
|||||||
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
||||||
):
|
):
|
||||||
line_index += 1
|
line_index += 1
|
||||||
|
# find the target specified in the current level in `parts` -> increase `indent` so we can search the next
|
||||||
indent += " "
|
indent += " "
|
||||||
|
# the index of the first line in the (currently found) block *body*
|
||||||
line_index += 1
|
line_index += 1
|
||||||
|
|
||||||
if line_index >= len(lines):
|
if line_index >= len(lines):
|
||||||
raise ValueError(f" {object_name} does not match any function or class in {module}.")
|
raise ValueError(f" {object_name} does not match any function or class in {module}.")
|
||||||
|
|
||||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
# `indent` is already one level deeper than the (found) class/func block's definition header
|
||||||
start_index = line_index - 1
|
|
||||||
while line_index < len(lines) and _should_continue(lines[line_index], indent):
|
|
||||||
line_index += 1
|
|
||||||
# Clean up empty lines at the end (if any).
|
|
||||||
while len(lines[line_index - 1]) <= 1:
|
|
||||||
line_index -= 1
|
|
||||||
|
|
||||||
code_lines = lines[start_index:line_index]
|
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||||
return "".join(code_lines)
|
# `start_index` is the index of the class/func block's definition header
|
||||||
|
start_index = line_index - 1
|
||||||
|
end_index = find_block_end(lines, start_index, len(indent))
|
||||||
|
|
||||||
|
code = "".join(lines[start_index:end_index])
|
||||||
|
return (code, (lines, start_index, end_index)) if return_indices else code
|
||||||
|
|
||||||
|
|
||||||
|
def replace_code(code: str, replace_pattern: str) -> str:
    """Apply to `code` the substitutions described by a pattern of the form `with X1->X2,Y1->Y2,Z1->Z2`.

    Args:
        code (`str`): The code to be modified.
        replace_pattern (`str`): The pattern used to modify `code`.

    Returns:
        `str`: The modified code.
    """
    if len(replace_pattern) == 0:
        return code

    for raw_rule in replace_pattern.replace("with", "").split(","):
        rule = _re_replace_pattern.search(raw_rule)
        if rule is None:
            # Not of the form `X->Y`: nothing to substitute for this piece.
            continue
        source, target, option = rule.groups()
        code = re.sub(source, target, code)
        if option.strip() == "all-casing":
            # Also substitute the lower-cased and upper-cased variants.
            code = re.sub(source.lower(), target.lower(), code)
            code = re.sub(source.upper(), target.upper(), code)

    return code
|
||||||
|
|
||||||
|
|
||||||
|
def find_code_and_splits(object_name: str, base_path: str, buffer: dict = None):
    """Locate the code of the object named `object_name` and split it into blocks, with optional caching.

    Results are stored in `buffer` under the key `(object_name, base_path)`; when that key is already present the
    stored result is returned and no lookup is performed.

    Args:
        object_name (`str`):
            The name of the object, e.g. `transformers.models.bert.modeling_bert.BertAttention` or
            `tests.models.llama.test_modeling_llama.LlamaModelTest.test_config`.
        base_path (`str`):
            The path to the base directory within which the search will be performed. It could be either
            `TRANSFORMERS_PATH` or `MODEL_TEST_PATH`.
        buffer (`dict`, *optional*):
            The buffer used to store the previous results in order to speed up the process. It is updated in place
            when a new result is computed. When omitted, a fresh (throw-away) dictionary is used.

    Returns:
        lines (`List[str]`):
            The lines of the whole file where the object is defined.
        code (`str`):
            The object's code.
        code_splits (`List[Tuple[str, int, int]]`):
            `code` split into blocks. See `split_code_into_blocks`.
    """
    cache = {} if buffer is None else buffer
    key = (object_name, base_path)

    if key not in cache:
        code, (lines, target_start_index, target_end_index) = find_code_in_transformers(
            object_name, base_path=base_path, return_indices=True
        )

        # `get_indent(code)` is the indent of the class/func definition header, whereas `split_code_into_blocks`
        # expects the indent level of the block body, i.e. one level (4 spaces) deeper.
        body_indent_level = len(get_indent(code)) + 4
        code_splits = split_code_into_blocks(
            lines, target_start_index, target_end_index, body_indent_level, backtrace=True
        )
        cache[key] = lines, code, code_splits

    lines, code, code_splits = cache[key]
    return lines, code, code_splits
|
||||||
|
|
||||||
|
|
||||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||||
@@ -285,7 +561,7 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
|
|||||||
diff_index += 1
|
diff_index += 1
|
||||||
|
|
||||||
|
|
||||||
def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
|
def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = None) -> Optional[List[Tuple[str, int]]]:
|
||||||
"""
|
"""
|
||||||
Check if the code commented as a copy in a file matches the original.
|
Check if the code commented as a copy in a file matches the original.
|
||||||
|
|
||||||
@@ -294,11 +570,15 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||||||
The name of the file to check.
|
The name of the file to check.
|
||||||
overwrite (`bool`, *optional*, defaults to `False`):
|
overwrite (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to overwrite the copies when they don't match.
|
Whether or not to overwrite the copies when they don't match.
|
||||||
|
buffer (`dict`, *optional*):
|
||||||
|
The buffer used to store the previous results in order to speed up the process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
|
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
|
||||||
with the name of the object having a diff and the line number where there is the first diff.
|
with the name of the object having a diff and the line number where there is the first diff.
|
||||||
"""
|
"""
|
||||||
|
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
|
||||||
|
|
||||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
diffs = []
|
diffs = []
|
||||||
@@ -317,16 +597,31 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||||||
# There is some copied code here, let's retrieve the original.
|
# There is some copied code here, let's retrieve the original.
|
||||||
indent, object_name, replace_pattern = search.groups()
|
indent, object_name, replace_pattern = search.groups()
|
||||||
|
|
||||||
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
|
# Find the file lines, the object's code, and its blocks
|
||||||
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
|
target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits(
|
||||||
|
object_name, base_path, buffer=buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
# code replaced by the patterns
|
||||||
|
theoretical_code_blocks = OrderedDict()
|
||||||
|
for name, start, end in theoretical_code_splits:
|
||||||
|
name = replace_code(name, replace_pattern)
|
||||||
|
code = "".join(target_lines[start:end])
|
||||||
|
code = replace_code(code, replace_pattern)
|
||||||
|
theoretical_code_blocks[name] = code
|
||||||
|
|
||||||
theoretical_indent = get_indent(theoretical_code)
|
theoretical_indent = get_indent(theoretical_code)
|
||||||
|
|
||||||
|
# `start_index` is the index of the first line (the definition header) after `# Copied from`.
|
||||||
|
# (`indent != theoretical_indent` doesn't seem to occur so far, not sure what this case is for.)
|
||||||
start_index = line_index + 1 if indent == theoretical_indent else line_index
|
start_index = line_index + 1 if indent == theoretical_indent else line_index
|
||||||
|
# enter the block body
|
||||||
line_index = start_index + 1
|
line_index = start_index + 1
|
||||||
|
|
||||||
subcode = "\n".join(theoretical_code.split("\n")[1:])
|
subcode = "\n".join(theoretical_code.split("\n")[1:])
|
||||||
indent = get_indent(subcode)
|
indent = get_indent(subcode)
|
||||||
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
|
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
|
||||||
|
# We can't call `find_block_end` directly because the special `# End copy` marker must also be handled here.
|
||||||
should_continue = True
|
should_continue = True
|
||||||
while line_index < len(lines) and should_continue:
|
while line_index < len(lines) and should_continue:
|
||||||
line_index += 1
|
line_index += 1
|
||||||
@@ -336,33 +631,118 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||||||
# There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
|
# There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
|
||||||
# used.
|
# used.
|
||||||
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
||||||
|
# `line_index` is outside the block
|
||||||
# Clean up empty lines at the end (if any).
|
# Clean up empty lines at the end (if any).
|
||||||
while len(lines[line_index - 1]) <= 1:
|
while len(lines[line_index - 1]) <= 1:
|
||||||
line_index -= 1
|
line_index -= 1
|
||||||
|
|
||||||
observed_code_lines = lines[start_index:line_index]
|
# Split the observed code into blocks
|
||||||
observed_code = "".join(observed_code_lines)
|
observed_code_splits = split_code_into_blocks(lines, start_index, line_index, len(indent), backtrace=True)
|
||||||
|
|
||||||
# Before comparing, use the `replace_pattern` on the original code.
|
is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ")
|
||||||
if len(replace_pattern) > 0:
|
# sanity check
|
||||||
patterns = replace_pattern.replace("with", "").split(",")
|
_sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class)
|
||||||
patterns = [_re_replace_pattern.search(p) for p in patterns]
|
|
||||||
for pattern in patterns:
|
# observed code in a structured way (a dict mapping block names to blocks' code)
|
||||||
if pattern is None:
|
observed_code_blocks = OrderedDict()
|
||||||
continue
|
for name, start, end in observed_code_splits:
|
||||||
obj1, obj2, option = pattern.groups()
|
code = "".join(lines[start:end])
|
||||||
theoretical_code = re.sub(obj1, obj2, theoretical_code)
|
observed_code_blocks[name] = code
|
||||||
if option.strip() == "all-casing":
|
|
||||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
# Below, we change some names in `theoretical_code_blocks` and `observed_code_blocks`. These mappings map the
|
||||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
# original names to the modified names: this is used to restore the original order of the code blocks.
|
||||||
|
name_mappings_1 = {k: k for k in theoretical_code_blocks.keys()}
|
||||||
|
name_mappings_2 = {k: k for k in observed_code_blocks.keys()}
|
||||||
|
|
||||||
|
# Update code blocks' name and content:
|
||||||
|
# If `"# Ignore copy"` is found in a block of the observed code:
|
||||||
|
# 1. if it's a block only in the observed code --> add it to the theoretical code.
|
||||||
|
# 2. if it's also in the theoretical code () --> put its content (body) to the corresponding block under the
|
||||||
|
# same name in the theoretical code.
|
||||||
|
# In both cases, we change the name to have a prefix `_ignored_` so we know if we can discard them during the
|
||||||
|
# comparison.
|
||||||
|
ignored_existing_block_index = 0
|
||||||
|
ignored_new_block_index = 0
|
||||||
|
for name in list(observed_code_blocks.keys()):
|
||||||
|
code = observed_code_blocks[name]
|
||||||
|
if "# Ignore copy" in code:
|
||||||
|
if name in theoretical_code_blocks:
|
||||||
|
# in the target --> just copy the content
|
||||||
|
del theoretical_code_blocks[name]
|
||||||
|
theoretical_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
|
||||||
|
name_mappings_1[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
|
||||||
|
|
||||||
|
del observed_code_blocks[name]
|
||||||
|
observed_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
|
||||||
|
name_mappings_2[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
|
||||||
|
ignored_existing_block_index += 1
|
||||||
|
else:
|
||||||
|
# not in the target --> add it
|
||||||
|
theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||||
|
name_mappings_1[
|
||||||
|
f"_ignored_new_block_{ignored_new_block_index}"
|
||||||
|
] = f"_ignored_new_block_{ignored_new_block_index}"
|
||||||
|
|
||||||
|
del observed_code_blocks[name]
|
||||||
|
observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||||
|
name_mappings_2[name] = f"_ignored_new_block_{ignored_new_block_index}"
|
||||||
|
ignored_new_block_index += 1
|
||||||
|
|
||||||
|
# Respect the original block order:
|
||||||
|
# 1. in `theoretical_code_blocks`: the new blocks will follow the existing ones
|
||||||
|
# 2. in `observed_code_blocks`: the original order are kept with names modified potentially. This is necessary
|
||||||
|
# to compute the correct `diff_index` if `overwrite=True` and there is a diff.
|
||||||
|
theoretical_code_blocks = {
|
||||||
|
name_mappings_1[orig_name]: theoretical_code_blocks[name_mappings_1[orig_name]]
|
||||||
|
for orig_name in name_mappings_1
|
||||||
|
}
|
||||||
|
observed_code_blocks = {
|
||||||
|
name_mappings_2[orig_name]: observed_code_blocks[name_mappings_2[orig_name]]
|
||||||
|
for orig_name in name_mappings_2
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ignore the blocks specified to be ignored. This is the version used to check if there is a mismatch
|
||||||
|
theoretical_code_blocks_clean = {
|
||||||
|
k: v
|
||||||
|
for k, v in theoretical_code_blocks.items()
|
||||||
|
if not (k.startswith(("_ignored_existing_block_", "_ignored_new_block_")))
|
||||||
|
}
|
||||||
|
theoretical_code = "".join(list(theoretical_code_blocks_clean.values()))
|
||||||
|
|
||||||
|
# stylify `theoretical_code` before compare (this is needed only when `replace_pattern` is not empty)
|
||||||
|
if replace_pattern:
|
||||||
theoretical_code = stylify(theoretical_code)
|
theoretical_code = stylify(theoretical_code)
|
||||||
|
# Remove `\n\n` in `theoretical_code` before compare (so no empty line)
|
||||||
|
while "\n\n" in theoretical_code:
|
||||||
|
theoretical_code = theoretical_code.replace("\n\n", "\n")
|
||||||
|
|
||||||
|
# Compute `observed_code` where we don't include any empty line + keep track the line index between the
|
||||||
|
# original/processed `observed_code` so we can have the correct `diff_index`.
|
||||||
|
idx_to_orig_idx_mapping_for_observed_code_lines = {}
|
||||||
|
idx = -1
|
||||||
|
orig_idx = -1
|
||||||
|
observed_code = ""
|
||||||
|
for name, code in observed_code_blocks.items():
|
||||||
|
if code.endswith("\n"):
|
||||||
|
code = code[:-1]
|
||||||
|
for code_line in code.split("\n"):
|
||||||
|
orig_idx += 1
|
||||||
|
if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):
|
||||||
|
idx += 1
|
||||||
|
observed_code += code_line + "\n"
|
||||||
|
idx_to_orig_idx_mapping_for_observed_code_lines[idx] = orig_idx
|
||||||
|
|
||||||
# Test for a diff and act accordingly.
|
# Test for a diff and act accordingly.
|
||||||
diff_index = check_codes_match(observed_code, theoretical_code)
|
diff_index = check_codes_match(observed_code, theoretical_code)
|
||||||
if diff_index is not None:
|
if diff_index is not None:
|
||||||
|
# switch to the index in the original `observed_code` (i.e. before removing empty lines)
|
||||||
|
diff_index = idx_to_orig_idx_mapping_for_observed_code_lines[diff_index]
|
||||||
diffs.append([object_name, diff_index + start_index + 1])
|
diffs.append([object_name, diff_index + start_index + 1])
|
||||||
if overwrite:
|
if overwrite:
|
||||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
# `theoretical_code_to_write` is a single string but may have several lines.
|
||||||
|
theoretical_code_to_write = stylify("".join(list(theoretical_code_blocks.values())))
|
||||||
|
lines = lines[:start_index] + [theoretical_code_to_write] + lines[line_index:]
|
||||||
|
# Here we treat it as a single entry in `lines`.
|
||||||
line_index = start_index + 1
|
line_index = start_index + 1
|
||||||
|
|
||||||
if overwrite and len(diffs) > 0:
|
if overwrite and len(diffs) > 0:
|
||||||
@@ -373,7 +753,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||||||
return diffs
|
return diffs
|
||||||
|
|
||||||
|
|
||||||
def check_copies(overwrite: bool = False):
|
def check_copies(overwrite: bool = False, file: str = None):
|
||||||
"""
|
"""
|
||||||
Check every file is copy-consistent with the original. Also check the model list in the main README and other
|
Check every file is copy-consistent with the original. Also check the model list in the main README and other
|
||||||
READMEs are consistent.
|
READMEs are consistent.
|
||||||
@@ -381,14 +761,21 @@ def check_copies(overwrite: bool = False):
|
|||||||
Args:
|
Args:
|
||||||
overwrite (`bool`, *optional*, defaults to `False`):
|
overwrite (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to overwrite the copies when they don't match.
|
Whether or not to overwrite the copies when they don't match.
|
||||||
|
file (`str`, *optional*):
|
||||||
|
The path to a specific file to check and/or fix.
|
||||||
"""
|
"""
|
||||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
buffer = {}
|
||||||
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
|
|
||||||
all_files = list(all_files) + list(all_test_files)
|
if file is None:
|
||||||
|
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||||
|
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
|
||||||
|
all_files = list(all_files) + list(all_test_files)
|
||||||
|
else:
|
||||||
|
all_files = [file]
|
||||||
|
|
||||||
diffs = []
|
diffs = []
|
||||||
for filename in all_files:
|
for filename in all_files:
|
||||||
new_diffs = is_copy_consistent(filename, overwrite)
|
new_diffs = is_copy_consistent(filename, overwrite, buffer)
|
||||||
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
|
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
|
||||||
if not overwrite and len(diffs) > 0:
|
if not overwrite and len(diffs) > 0:
|
||||||
diff = "\n".join(diffs)
|
diff = "\n".join(diffs)
|
||||||
@@ -733,9 +1120,10 @@ def check_readme(overwrite: bool = False):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--file", type=str, default=None, help="A specific file to check and/or fix")
|
||||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
check_readme(args.fix_and_overwrite)
|
check_readme(args.fix_and_overwrite)
|
||||||
check_copies(args.fix_and_overwrite)
|
check_copies(args.fix_and_overwrite, args.file)
|
||||||
check_full_copies(args.fix_and_overwrite)
|
check_full_copies(args.fix_and_overwrite)
|
||||||
|
|||||||
Reference in New Issue
Block a user