Get a better error when check_copies fails (#7457)
* Get a better error when check_copies fails * Fix tests
This commit is contained in:
@@ -55,7 +55,7 @@ class CopyCheckTester(unittest.TestCase):
|
|||||||
with open(fname, "w") as f:
|
with open(fname, "w") as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
if overwrite_result is None:
|
if overwrite_result is None:
|
||||||
self.assertTrue(check_copies.is_copy_consistent(fname))
|
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
|
||||||
else:
|
else:
|
||||||
check_copies.is_copy_consistent(f.name, overwrite=True)
|
check_copies.is_copy_consistent(f.name, overwrite=True)
|
||||||
with open(fname, "r") as f:
|
with open(fname, "r") as f:
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False):
|
|||||||
"""
|
"""
|
||||||
with open(filename, "r", encoding="utf-8") as f:
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
found_diff = False
|
diffs = []
|
||||||
line_index = 0
|
line_index = 0
|
||||||
# Not a foor loop cause `lines` is going to change (if `overwrite=True`).
|
# Not a foor loop cause `lines` is going to change (if `overwrite=True`).
|
||||||
while line_index < len(lines):
|
while line_index < len(lines):
|
||||||
@@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False):
|
|||||||
|
|
||||||
# Test for a diff and act accordingly.
|
# Test for a diff and act accordingly.
|
||||||
if observed_code != theoretical_code:
|
if observed_code != theoretical_code:
|
||||||
found_diff = True
|
diffs.append([object_name, start_index])
|
||||||
if overwrite:
|
if overwrite:
|
||||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
||||||
line_index = start_index + 1
|
line_index = start_index + 1
|
||||||
|
|
||||||
if overwrite and found_diff:
|
if overwrite and len(diffs) > 0:
|
||||||
# Warn the user a file has been modified.
|
# Warn the user a file has been modified.
|
||||||
print(f"Detected changes, rewriting {filename}.")
|
print(f"Detected changes, rewriting {filename}.")
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
f.writelines(lines)
|
f.writelines(lines)
|
||||||
return not found_diff
|
return diffs
|
||||||
|
|
||||||
|
|
||||||
def check_copies(overwrite: bool = False):
|
def check_copies(overwrite: bool = False):
|
||||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||||
diffs = []
|
diffs = []
|
||||||
for filename in all_files:
|
for filename in all_files:
|
||||||
consistent = is_copy_consistent(filename, overwrite)
|
new_diffs = is_copy_consistent(filename, overwrite)
|
||||||
if not consistent:
|
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
|
||||||
diffs.append(filename)
|
|
||||||
if not overwrite and len(diffs) > 0:
|
if not overwrite and len(diffs) > 0:
|
||||||
diff = "\n".join(diffs)
|
diff = "\n".join(diffs)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Found copy inconsistencies in the following files:\n"
|
"Found the follwing copy inconsistencies:\n"
|
||||||
+ diff
|
+ diff
|
||||||
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
|
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user