Get a better error when check_copies fails (#7457)

* Get a better error when check_copies fails

* Fix tests
This commit is contained in:
Sylvain Gugger
2020-09-30 04:05:14 -04:00
committed by GitHub
parent bef0175168
commit 4ba248748f
2 changed files with 8 additions and 9 deletions

View File

@@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False):
"""
with open(filename, "r", encoding="utf-8") as f:
lines = f.readlines()
found_diff = False
diffs = []
line_index = 0
# Not a foor loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines):
@@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False):
# Test for a diff and act accordingly.
if observed_code != theoretical_code:
found_diff = True
diffs.append([object_name, start_index])
if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
line_index = start_index + 1
if overwrite and found_diff:
if overwrite and len(diffs) > 0:
# Warn the user a file has been modified.
print(f"Detected changes, rewriting {filename}.")
with open(filename, "w", encoding="utf-8") as f:
f.writelines(lines)
return not found_diff
return diffs
def check_copies(overwrite: bool = False):
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
for filename in all_files:
consistent = is_copy_consistent(filename, overwrite)
if not consistent:
diffs.append(filename)
new_diffs = is_copy_consistent(filename, overwrite)
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs)
raise Exception(
"Found copy inconsistencies in the following files:\n"
"Found the follwing copy inconsistencies:\n"
+ diff
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
)