fix regexes with escape sequence (#17943)
This commit is contained in:
@@ -78,9 +78,9 @@ def get_relative_imports(module_file):
|
|||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Imports of the form `import .xxx`
|
# Imports of the form `import .xxx`
|
||||||
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
|
relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
|
||||||
# Imports of the form `from .xxx import yyy`
|
# Imports of the form `from .xxx import yyy`
|
||||||
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
|
relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
|
||||||
# Unique-ify
|
# Unique-ify
|
||||||
return list(set(relative_imports))
|
return list(set(relative_imports))
|
||||||
|
|
||||||
@@ -122,9 +122,9 @@ def check_imports(filename):
|
|||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# Imports of the form `import xxx`
|
# Imports of the form `import xxx`
|
||||||
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
|
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
|
||||||
# Imports of the form `from xxx import yyy`
|
# Imports of the form `from xxx import yyy`
|
||||||
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
|
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
|
||||||
# Only keep the top-level module
|
# Only keep the top-level module
|
||||||
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
||||||
|
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ def dtype_byte_size(dtype):
|
|||||||
"""
|
"""
|
||||||
if dtype == torch.bool:
|
if dtype == torch.bool:
|
||||||
return 1 / 8
|
return 1 / 8
|
||||||
bit_search = re.search("[^\d](\d+)$", str(dtype))
|
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||||
if bit_search is None:
|
if bit_search is None:
|
||||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||||
bit_size = int(bit_search.groups()[0])
|
bit_size = int(bit_search.groups()[0])
|
||||||
|
|||||||
Reference in New Issue
Block a user