Get a better error when check_copies fails (#7457)
* Get a better error when check_copies fails * Fix tests
This commit is contained in:
@@ -55,7 +55,7 @@ class CopyCheckTester(unittest.TestCase):
|
||||
with open(fname, "w") as f:
|
||||
f.write(code)
|
||||
if overwrite_result is None:
|
||||
self.assertTrue(check_copies.is_copy_consistent(fname))
|
||||
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
|
||||
else:
|
||||
check_copies.is_copy_consistent(f.name, overwrite=True)
|
||||
with open(fname, "r") as f:
|
||||
|
||||
@@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
found_diff = False
|
||||
diffs = []
|
||||
line_index = 0
|
||||
# Not a foor loop cause `lines` is going to change (if `overwrite=True`).
|
||||
while line_index < len(lines):
|
||||
@@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False):
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
if observed_code != theoretical_code:
|
||||
found_diff = True
|
||||
diffs.append([object_name, start_index])
|
||||
if overwrite:
|
||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
||||
line_index = start_index + 1
|
||||
|
||||
if overwrite and found_diff:
|
||||
if overwrite and len(diffs) > 0:
|
||||
# Warn the user a file has been modified.
|
||||
print(f"Detected changes, rewriting {filename}.")
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.writelines(lines)
|
||||
return not found_diff
|
||||
return diffs
|
||||
|
||||
|
||||
def check_copies(overwrite: bool = False):
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
consistent = is_copy_consistent(filename, overwrite)
|
||||
if not consistent:
|
||||
diffs.append(filename)
|
||||
new_diffs = is_copy_consistent(filename, overwrite)
|
||||
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
|
||||
if not overwrite and len(diffs) > 0:
|
||||
diff = "\n".join(diffs)
|
||||
raise Exception(
|
||||
"Found copy inconsistencies in the following files:\n"
|
||||
"Found the follwing copy inconsistencies:\n"
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user