From 04fd783cc50bcc6744634e7300b3828b38a4dc79 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 8 Feb 2021 04:58:25 -0500 Subject: [PATCH] Check copies match full class/function names (#10030) --- utils/check_copies.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/check_copies.py b/utils/check_copies.py index 4837371dcb..eabd10cc9b 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -36,7 +36,8 @@ def find_code_in_transformers(object_name): module = parts[i] while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")): i += 1 - module = os.path.join(module, parts[i]) + if i < len(parts): + module = os.path.join(module, parts[i]) if i >= len(parts): raise ValueError( f"`object_name` should begin with the name of a module of transformers but got {object_name}." @@ -49,7 +50,9 @@ def find_code_in_transformers(object_name): indent = "" line_index = 0 for name in parts[i + 1 :]: - while line_index < len(lines) and re.search(fr"^{indent}(class|def)\s+{name}", lines[line_index]) is None: + while ( + line_index < len(lines) and re.search(fr"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None + ): line_index += 1 indent += " " line_index += 1