Fixes for Modular Converter on Windows (#34266)

* Separator in regex

* Standardize separator for relative path in auto generated message

* open() encoding

* Replace `\` on `os.path.abspath`

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
hlky
2024-10-29 10:40:41 +00:00
committed by GitHub
parent 626c610a4d
commit 9e3d704e23

View File

@@ -56,7 +56,7 @@ def get_module_source_from_name(module_name: str) -> str:
if spec is None or spec.origin is None: if spec is None or spec.origin is None:
return f"Module {module_name} not found" return f"Module {module_name} not found"
with open(spec.origin, "r") as file: with open(spec.origin, "r", encoding="utf-8") as file:
source_code = file.read() source_code = file.read()
return source_code return source_code
@@ -1132,7 +1132,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
if pattern is not None: if pattern is not None:
model_name = pattern.groups()[0] model_name = pattern.groups()[0]
# Parse the Python file # Parse the Python file
with open(modular_file, "r") as file: with open(modular_file, "r", encoding="utf-8") as file:
code = file.read() code = file.read()
module = cst.parse_module(code) module = cst.parse_module(code)
wrapper = MetadataWrapper(module) wrapper = MetadataWrapper(module)
@@ -1143,7 +1143,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
if node != {}: if node != {}:
# Get relative path starting from src/transformers/ # Get relative path starting from src/transformers/
relative_path = re.search( relative_path = re.search(
rf"(src{os.sep}transformers{os.sep}.*|examples{os.sep}.*)", os.path.abspath(modular_file) r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
).group(1) ).group(1)
header = AUTO_GENERATED_MESSAGE.format( header = AUTO_GENERATED_MESSAGE.format(
@@ -1164,7 +1164,7 @@ def save_modeling_file(modular_file, converted_file):
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")] [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
) )
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0: if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f: with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][0]) f.write(converted_file[file_type][0])
else: else:
non_comment_lines = len( non_comment_lines = len(
@@ -1172,7 +1172,7 @@ def save_modeling_file(modular_file, converted_file):
) )
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0: if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
logger.warning("The modeling code contains errors, it's written without formatting") logger.warning("The modeling code contains errors, it's written without formatting")
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f: with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][1]) f.write(converted_file[file_type][1])