From 4c722e9e227bb5850172100ea51d8d498adf1aa7 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 29 Jun 2022 08:55:22 -0700 Subject: [PATCH] fix regexes with escape sequence (#17943) --- src/transformers/dynamic_module_utils.py | 8 ++++---- src/transformers/modeling_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 7670c4b668..7baafd214c 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -78,9 +78,9 @@ def get_relative_imports(module_file): content = f.read() # Imports of the form `import .xxx` - relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) # Imports of the form `from .xxx import yyy` - relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) # Unique-ify return list(set(relative_imports)) @@ -122,9 +122,9 @@ def check_imports(filename): content = f.read() # Imports of the form `import xxx` - imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) + imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) # Imports of the form `from xxx import yyy` - imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) + imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) # Only keep the top-level module imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f9f57ffe6a..e1621c6e5a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -219,7 +219,7 @@ def dtype_byte_size(dtype): """ if dtype == torch.bool: return 1 / 8 - bit_search = re.search("[^\d](\d+)$", str(dtype)) + bit_search = re.search(r"[^\d](\d+)$", str(dtype)) if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0])