From 4ba248748f779b6eb1317734a2493307b3c26431 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 30 Sep 2020 04:05:14 -0400 Subject: [PATCH] Get a better error when check_copies fails (#7457) * Get a better error when check_copies fails * Fix tests --- tests/test_utils_check_copies.py | 2 +- utils/check_copies.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_utils_check_copies.py b/tests/test_utils_check_copies.py index 2c6baba32f..24d05f7c4f 100644 --- a/tests/test_utils_check_copies.py +++ b/tests/test_utils_check_copies.py @@ -55,7 +55,7 @@ class CopyCheckTester(unittest.TestCase): with open(fname, "w") as f: f.write(code) if overwrite_result is None: - self.assertTrue(check_copies.is_copy_consistent(fname)) + self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0) else: check_copies.is_copy_consistent(f.name, overwrite=True) with open(fname, "r") as f: diff --git a/utils/check_copies.py b/utils/check_copies.py index fedd4357fe..2d4b9fbc06 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False): """ with open(filename, "r", encoding="utf-8") as f: lines = f.readlines() - found_diff = False + diffs = [] line_index = 0 # Not a foor loop cause `lines` is going to change (if `overwrite=True`). while line_index < len(lines): @@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False): # Test for a diff and act accordingly. if observed_code != theoretical_code: - found_diff = True + diffs.append([object_name, start_index]) if overwrite: lines = lines[:start_index] + [theoretical_code] + lines[line_index:] line_index = start_index + 1 - if overwrite and found_diff: + if overwrite and len(diffs) > 0: # Warn the user a file has been modified. print(f"Detected changes, rewriting {filename}.") with open(filename, "w", encoding="utf-8") as f: f.writelines(lines) - return not found_diff + return diffs def check_copies(overwrite: bool = False): all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) diffs = [] for filename in all_files: - consistent = is_copy_consistent(filename, overwrite) - if not consistent: - diffs.append(filename) + new_diffs = is_copy_consistent(filename, overwrite) + diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs] if not overwrite and len(diffs) > 0: diff = "\n".join(diffs) raise Exception( - "Found copy inconsistencies in the following files:\n" + "Found the follwing copy inconsistencies:\n" + diff + "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them." )