Support for easier multimodal use of modular (#35056)

* update modular and add examples

* style

* improve example comments

* style

* fix small logic issue for imports

* fix relative order issue when files do not make sense

* Improve comments

* trigger CIs
This commit is contained in:
Cyril Vallez
2024-12-04 15:13:11 +01:00
committed by GitHub
parent 46df859975
commit 1da1e0d7f2
13 changed files with 2424 additions and 103 deletions

View File

@@ -18,7 +18,7 @@ import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from typing import Dict, Set
import libcst as cst
@@ -48,7 +48,7 @@ def get_module_source_from_name(module_name: str) -> str:
# Extract the source code from the module name
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
raise ValueError(f"Cannot open file associated with {module_name} module.")
with open(spec.origin, "r", encoding="utf-8") as file:
source_code = file.read()
@@ -58,20 +58,40 @@ def get_module_source_from_name(module_name: str) -> str:
def preserve_case_replace(text, patterns: dict, default_name: str):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
def replace(match):
word = match.group(0)
result = patterns.get(word, default_name)
return result
matched_pattern = match.group(1)
next_char = match.group(2)
new_pattern = patterns.get(matched_pattern, default_name)
# In this case, the cased old model did not respect CamelCase and was all UPPERCASE, so we need to rely on next char
# The heuristic is: if next char is not a letter, then it is not part of a model name and result should be `new_name`.upper()
if len(patterns) == 2 and matched_pattern.isupper():
if not next_char.isalpha():
# `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased
new_pattern = patterns[matched_pattern.lower()].upper()
return new_pattern + next_char
return compiled_regex.sub(replace, text)
def convert_to_camelcase(text, old_name: str, default_old_name: str):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1)
return result
def get_cased_name(lowercase_name: str) -> str:
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
if lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
else:
return "".join(x.title() for x in lowercase_name.split("_"))
def get_lowercase_name(cased_name: str) -> str:
"""From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()}
if cased_name + "Config" in inverse_mapping:
return inverse_mapping[cased_name + "Config"]
else:
return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)])
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
@@ -84,43 +104,47 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
- LLaMa -> MyNewModel abd MyNewModel -> Llama
"""
def __init__(
self,
old_name,
new_name,
given_old_name=None,
given_new_name=None,
):
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
super().__init__()
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
if self.new_name in CONFIG_MAPPING_NAMES:
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
"Config", ""
) # the best source of truth for class names. Could also just use the ones de
self.cased_new_name = get_cased_name(self.new_name)
self.cased_old_name = get_cased_name(self.old_name)
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
# For some old models, `self.cased_old_name` == `old_name.upper()` in which case this overwrite previous entry
self.cased_old_name: self.cased_new_name,
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name
if self.old_name in CONFIG_MAPPING_NAMES:
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
if self.default_old_name.isupper():
self.default_old_name = self.default_old_name.capitalize()
# In case new_name is a prefix alias, and not the original new model name
self.original_new_model_name = original_new_model_name
self.only_doc = only_doc
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
def _replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
return updated_node.with_changes(value=update)
def leave_ClassDef(self, original_node, updated_node):
new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
return updated_node.with_changes(name=cst.Name(new_name))
@m.leave(m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
return self._replace_name(original_node, updated_node)
def leave_Name(self, original_node, updated_node):
if not self.only_doc:
return self._replace_name(original_node, updated_node)
return updated_node
def leave_ImportFrom(self, original_node, updated_node):
"""The imports from other file types (configuration, processing etc) should use original model name."""
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
patterns = "|".join(ALL_FILE_TYPES)
regex = rf"({patterns})_{self.new_name}"
new_source = re.sub(
regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
)
updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
return updated_node
DOCSTRING_NODE = m.SimpleStatementLine(
@@ -760,10 +784,12 @@ class ModelFileMapper(ModuleMapper):
remaining_dependencies.remove(dep)
relative_order[dep] = idx
idx += 1
# Add the class itself
remaining_dependencies.remove(class_name)
relative_order[class_name] = idx
idx += 1
# Add the class itself (it can sometimes already be present if the order of classes in the source file
# does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...)
if class_name in remaining_dependencies:
remaining_dependencies.remove(class_name)
relative_order[class_name] = idx
idx += 1
# Now add what still remains
remaining_dependencies = tuple(remaining_dependencies)
@@ -859,7 +885,24 @@ class ModelFileMapper(ModuleMapper):
return mapper
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
def common_partial_suffix(str1: str, str2: str) -> str:
"""Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
we do not consider it a common suffix and return `""`"""
common_suffix = ""
for i in range(1, min(len(str1), len(str2)) + 1):
if str1[-i] == str2[-i]:
common_suffix = str1[-i] + common_suffix
else:
break
# We do not allow full string suffix
if common_suffix == str1 or common_suffix == str2:
common_suffix = ""
return common_suffix
def replace_class_node(
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
"""
Replace a class node which inherits from another modeling class. This function works in the following way:
- start from the base class node of the inherited class (a cst.Node)
@@ -889,6 +932,36 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
original_node = mapper.classes[renamed_super_class]
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
new_name = class_node.name
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
if new_name.value != renamed_super_class:
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
# Note that this works even without common prefix, in which case it does not replace anything
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
temp_module = cst.Module(body=[original_node])
original_node = temp_module.visit(
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
).body[0]
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
additional_bases = [base for base in all_bases if base != original_super_class]
new_bases = []
for original_base in original_node.bases:
new_base = original_base
# we only potentially switch base for Name-based bases, not Attribute
if m.matches(original_base.value, m.Name()):
original_base_name = original_base.value.value
for additional_base_name in additional_bases:
suffix = common_partial_suffix(original_base_name, additional_base_name)
if len(suffix) > 0 and suffix[0].isupper():
new_name_node = original_base.value.with_changes(value=additional_base_name)
new_base = original_base.with_changes(value=new_name_node)
break
new_bases.append(new_base)
original_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
for f in original_node.body.body
@@ -942,12 +1015,17 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
# Extract the original docstring
updated_docstring = func.body[0].value.value
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
docstring_node = [
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
]
else:
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
end_meth.append(func)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
@@ -970,10 +1048,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
# Use decorators redefined in `modular_xxx.py` if any
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
name = class_node.name
return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
return original_node.with_changes(
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
)
TYPE_TO_FILE_TYPE = {
@@ -1014,14 +1092,18 @@ VARIABLES_AT_THE_BEGINNING = (
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`."""
def append_new_import_node(
node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
Also modifies `added_names` in-place accordingly."""
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
name_value = name.evaluated_name
if name_value not in unused_imports:
if name_value not in unused_imports and name_value not in added_names:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
added_names.add(name_value)
if len(names_to_keep) > 0:
new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
imports_to_keep.append(new_node)
@@ -1036,40 +1118,38 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
scopes = set(wrapper.resolve(ScopeProvider).values())
unused_imports = set()
import_ref_count = {}
import_ref_count = defaultdict(lambda: 0)
for scope in scopes:
for assignment in scope.assignments:
node = assignment.node
if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
ref_count = len(assignment.references)
name = assignment.name
# Similar imports may be redefined, and only used between their 1st and 2nd definition
# so if we already have a ref count > 0, the imports is actually used
if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
unused_imports.add(name)
import_ref_count[name] = ref_count
import_ref_count[name] = max(ref_count, import_ref_count[name])
# Similar imports may be redefined, and only used between their 1st and 2nd definition so if we already have
# a ref count > 0 at any point, the imports is actually used
unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}
imports_to_keep = []
# We need to keep track of which names were already imported, because some import may be duplicated from multiple sources
# or be both protected and unprotected due to inconsistency between models
added_names = set()
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
for node in all_imports:
if m.matches(node, m.If()): # handle safe imports
new_statements = []
for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements)
append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
else:
append_new_import_node(node, unused_imports, imports_to_keep)
append_new_import_node(node, unused_imports, added_names, imports_to_keep)
protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]
# If the same import is both protected and unprotected, only keep the protected one
for protected_node in protected_import_nodes:
for stmt_node in protected_node.body.body:
usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]
# Protected imports always appear at the end of all imports
return usual_import_nodes + protected_import_nodes
@@ -1102,12 +1182,10 @@ class ModularFileMapper(ModuleMapper):
Calling the method `create_modules()` after visit will create all modules based on this modular file.
"""
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
def __init__(self, python_module, new_name):
super().__init__(python_module)
# fmt: off
self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
self.given_old_name = given_old_name
self.given_new_name = given_new_name
self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
@@ -1191,11 +1269,11 @@ class ModularFileMapper(ModuleMapper):
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
name_prefixes = self.infer_new_model_name()
for file, module in self.model_specific_modules.items():
file_model_name = file.split(".")[-2]
renamer = ReplaceNameTransformer(
file_model_name, self.model_name, self.given_old_name, self.given_new_name
)
new_name = name_prefixes[file]
renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
renamed_module = module.visit(renamer)
self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
renamed_module,
@@ -1288,6 +1366,87 @@ class ModularFileMapper(ModuleMapper):
return relative_order
def infer_new_model_name(self) -> dict:
"""Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
so we have something like:
```python
class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
pass
```
with the `Text` prefix added to the model name.
However, in case of multiple prefix used, we raise a warning and use the most frequent prefix, to avoid parsing
the same file multiple times and inconsistencies in the objects added from dependencies.
If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.
"""
prefix_model_name_mapping = defaultdict(Counter)
cased_default_name = get_cased_name(self.model_name)
# Iterate over all new classes to get modeling super classes
for class_name, class_node in self.classes.items():
modeling_bases = [
k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
]
if len(modeling_bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
)
if len(modeling_bases) == 1:
filename = self.model_specific_imported_objects[modeling_bases[0]]
cased_model_name = cased_default_name # the default name prefix
suffix = common_partial_suffix(class_name, modeling_bases[0])
if len(suffix) > 0 and suffix[0].isupper():
cased_model_name = class_name.replace(suffix, "")
prefix_model_name_mapping[filename].update([cased_model_name])
# Check if we found multiple prefixes for some modeling files
final_name_mapping = {}
for file, prefixes_counter in prefix_model_name_mapping.items():
if len(prefixes_counter) > 1:
_, total = prefixes_counter.most_common(1)[0]
most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
# if the default name is in the pool of equally used prefixes, use it, otherwise last encountered
final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
else:
final_name = list(prefixes_counter)[0]
# Check if the prefix can be used without collisions in the names
old_cased_model_name = get_cased_name(file.split(".")[-2])
old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
# Raise adequate warning depending on the situation
has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
if final_name != cased_default_name and has_prefix_collision:
if len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
)
else:
logger.warning(
f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
"already present in the source file and will likely cause consistency issues. For this reason we fallback "
f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
)
final_name = cased_default_name
elif len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
"in all the modular (best)."
)
final_name_mapping[file] = get_lowercase_name(final_name)
# Check we are not missing imported files
for file in self.model_specific_modules.keys():
if file not in final_name_mapping.keys():
final_name_mapping[file] = self.model_name
return final_name_mapping
def check_dependencies_and_create_import_node(
file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
@@ -1338,11 +1497,11 @@ def get_class_node_and_dependencies(
class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
the modular that we nay need.
"""
bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects]
if len(bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}."
)
# An exception was already raised if this has len > 1
model_specific_bases = [
k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
]
super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None
file_type = find_file_type(class_name)
file_to_update = files[file_type]
@@ -1352,19 +1511,17 @@ def get_class_node_and_dependencies(
imported_objects = modular_mapper.imported_objects_per_file[file_type]
# We need to replace the class node with the transformers (modeling file) super class node
if len(bases) == 1:
super_class = bases[0]
if super_class is not None:
super_file_name = modular_mapper.model_specific_imported_objects[super_class]
# Get the mapper corresponding to the inherited class
mapper = modular_mapper.visited_modules[super_file_name]
# Rename the super class according to the exact same rule we used when renaming the whole module
renamer = modular_mapper.renamers[super_file_name]
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)
# Create the new class node
updated_node = replace_class_node(mapper, node, renamed_super_class)
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
# Grab all immediate dependencies of the new node
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
@@ -1468,7 +1625,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
return files
def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
def convert_modular_file(modular_file):
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}
if pattern is not None:
@@ -1478,8 +1635,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name)
cst_transformers = ModularFileMapper(module, model_name)
wrapper.visit(cst_transformers)
for file, module in create_modules(cst_transformers).items():
if module != {}:
@@ -1522,20 +1678,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
default=["src/transformers/models/gemma/modular_gemma.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@@ -1544,5 +1690,5 @@ if __name__ == "__main__":
for file_name in find_priority_list(args.files_to_parse):
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
converted_files = convert_modular_file(file_name)
converter = save_modeling_file(file_name, converted_files)