Support for easier multimodal use of modular (#35056)
* update modular and add examples * style * improve example comments * style * fix small logic issue for imports * fix relative order issue when files do not make sense * Improve comments * trigger CIs
This commit is contained in:
@@ -18,7 +18,7 @@ import importlib
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections import Counter, defaultdict, deque
|
||||
from typing import Dict, Set
|
||||
|
||||
import libcst as cst
|
||||
@@ -48,7 +48,7 @@ def get_module_source_from_name(module_name: str) -> str:
|
||||
# Extract the source code from the module name
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None or spec.origin is None:
|
||||
return f"Module {module_name} not found"
|
||||
raise ValueError(f"Cannot open file associated with {module_name} module.")
|
||||
|
||||
with open(spec.origin, "r", encoding="utf-8") as file:
|
||||
source_code = file.read()
|
||||
@@ -58,20 +58,40 @@ def get_module_source_from_name(module_name: str) -> str:
|
||||
def preserve_case_replace(text, patterns: dict, default_name: str):
|
||||
# Create a regex pattern to match all variations
|
||||
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
|
||||
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
|
||||
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
def replace(match):
|
||||
word = match.group(0)
|
||||
result = patterns.get(word, default_name)
|
||||
return result
|
||||
matched_pattern = match.group(1)
|
||||
next_char = match.group(2)
|
||||
new_pattern = patterns.get(matched_pattern, default_name)
|
||||
|
||||
# In this case, the cased old model did not respect CamelCase and was all UPPERCASE, so we need to rely on next char
|
||||
# The heuristic is: if next char is not a letter, then it is not part of a model name and result should be `new_name`.upper()
|
||||
if len(patterns) == 2 and matched_pattern.isupper():
|
||||
if not next_char.isalpha():
|
||||
# `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased
|
||||
new_pattern = patterns[matched_pattern.lower()].upper()
|
||||
|
||||
return new_pattern + next_char
|
||||
|
||||
return compiled_regex.sub(replace, text)
|
||||
|
||||
|
||||
def convert_to_camelcase(text, old_name: str, default_old_name: str):
|
||||
# Regex pattern to match consecutive uppercase letters and lowercase the first set
|
||||
result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1)
|
||||
return result
|
||||
def get_cased_name(lowercase_name: str) -> str:
|
||||
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
|
||||
if lowercase_name in CONFIG_MAPPING_NAMES:
|
||||
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
|
||||
else:
|
||||
return "".join(x.title() for x in lowercase_name.split("_"))
|
||||
|
||||
|
||||
def get_lowercase_name(cased_name: str) -> str:
|
||||
"""From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
|
||||
inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()}
|
||||
if cased_name + "Config" in inverse_mapping:
|
||||
return inverse_mapping[cased_name + "Config"]
|
||||
else:
|
||||
return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)])
|
||||
|
||||
|
||||
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
@@ -84,43 +104,47 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
- LLaMa -> MyNewModel abd MyNewModel -> Llama
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
old_name,
|
||||
new_name,
|
||||
given_old_name=None,
|
||||
given_new_name=None,
|
||||
):
|
||||
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
|
||||
super().__init__()
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
self.default_name = "".join(x.title() for x in new_name.split("_"))
|
||||
if self.new_name in CONFIG_MAPPING_NAMES:
|
||||
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
|
||||
"Config", ""
|
||||
) # the best source of truth for class names. Could also just use the ones de
|
||||
self.cased_new_name = get_cased_name(self.new_name)
|
||||
self.cased_old_name = get_cased_name(self.old_name)
|
||||
self.patterns = {
|
||||
old_name: new_name,
|
||||
old_name.upper(): new_name.upper(),
|
||||
"".join(x.title() for x in old_name.split("_")): self.default_name,
|
||||
# For some old models, `self.cased_old_name` == `old_name.upper()` in which case this overwrite previous entry
|
||||
self.cased_old_name: self.cased_new_name,
|
||||
}
|
||||
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
|
||||
self.patterns[given_old_name] = given_new_name
|
||||
if self.old_name in CONFIG_MAPPING_NAMES:
|
||||
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
|
||||
if self.default_old_name.isupper():
|
||||
self.default_old_name = self.default_old_name.capitalize()
|
||||
# In case new_name is a prefix alias, and not the original new model name
|
||||
self.original_new_model_name = original_new_model_name
|
||||
self.only_doc = only_doc
|
||||
|
||||
@m.leave(m.Name() | m.SimpleString() | m.Comment())
|
||||
def replace_name(self, original_node, updated_node):
|
||||
def _replace_name(self, original_node, updated_node):
|
||||
if re.findall(r"# Copied from", updated_node.value):
|
||||
return cst.RemoveFromParent()
|
||||
update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
|
||||
update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
|
||||
return updated_node.with_changes(value=update)
|
||||
|
||||
def leave_ClassDef(self, original_node, updated_node):
|
||||
new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
|
||||
return updated_node.with_changes(name=cst.Name(new_name))
|
||||
@m.leave(m.SimpleString() | m.Comment())
|
||||
def replace_name(self, original_node, updated_node):
|
||||
return self._replace_name(original_node, updated_node)
|
||||
|
||||
def leave_Name(self, original_node, updated_node):
|
||||
if not self.only_doc:
|
||||
return self._replace_name(original_node, updated_node)
|
||||
return updated_node
|
||||
|
||||
def leave_ImportFrom(self, original_node, updated_node):
|
||||
"""The imports from other file types (configuration, processing etc) should use original model name."""
|
||||
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
|
||||
patterns = "|".join(ALL_FILE_TYPES)
|
||||
regex = rf"({patterns})_{self.new_name}"
|
||||
new_source = re.sub(
|
||||
regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
|
||||
)
|
||||
updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
|
||||
return updated_node
|
||||
|
||||
|
||||
DOCSTRING_NODE = m.SimpleStatementLine(
|
||||
@@ -760,10 +784,12 @@ class ModelFileMapper(ModuleMapper):
|
||||
remaining_dependencies.remove(dep)
|
||||
relative_order[dep] = idx
|
||||
idx += 1
|
||||
# Add the class itself
|
||||
remaining_dependencies.remove(class_name)
|
||||
relative_order[class_name] = idx
|
||||
idx += 1
|
||||
# Add the class itself (it can sometimes already be present if the order of classes in the source file
|
||||
# does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...)
|
||||
if class_name in remaining_dependencies:
|
||||
remaining_dependencies.remove(class_name)
|
||||
relative_order[class_name] = idx
|
||||
idx += 1
|
||||
|
||||
# Now add what still remains
|
||||
remaining_dependencies = tuple(remaining_dependencies)
|
||||
@@ -859,7 +885,24 @@ class ModelFileMapper(ModuleMapper):
|
||||
return mapper
|
||||
|
||||
|
||||
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
|
||||
def common_partial_suffix(str1: str, str2: str) -> str:
|
||||
"""Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
|
||||
we do not consider it a common suffix and return `""`"""
|
||||
common_suffix = ""
|
||||
for i in range(1, min(len(str1), len(str2)) + 1):
|
||||
if str1[-i] == str2[-i]:
|
||||
common_suffix = str1[-i] + common_suffix
|
||||
else:
|
||||
break
|
||||
# We do not allow full string suffix
|
||||
if common_suffix == str1 or common_suffix == str2:
|
||||
common_suffix = ""
|
||||
return common_suffix
|
||||
|
||||
|
||||
def replace_class_node(
|
||||
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
|
||||
):
|
||||
"""
|
||||
Replace a class node which inherits from another modeling class. This function works in the following way:
|
||||
- start from the base class node of the inherited class (a cst.Node)
|
||||
@@ -889,6 +932,36 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
|
||||
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
|
||||
|
||||
original_node = mapper.classes[renamed_super_class]
|
||||
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
|
||||
new_name = class_node.name
|
||||
|
||||
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
|
||||
if new_name.value != renamed_super_class:
|
||||
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
|
||||
# Note that this works even without common prefix, in which case it does not replace anything
|
||||
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
|
||||
temp_module = cst.Module(body=[original_node])
|
||||
original_node = temp_module.visit(
|
||||
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
|
||||
).body[0]
|
||||
|
||||
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
|
||||
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
|
||||
additional_bases = [base for base in all_bases if base != original_super_class]
|
||||
new_bases = []
|
||||
for original_base in original_node.bases:
|
||||
new_base = original_base
|
||||
# we only potentially switch base for Name-based bases, not Attribute
|
||||
if m.matches(original_base.value, m.Name()):
|
||||
original_base_name = original_base.value.value
|
||||
for additional_base_name in additional_bases:
|
||||
suffix = common_partial_suffix(original_base_name, additional_base_name)
|
||||
if len(suffix) > 0 and suffix[0].isupper():
|
||||
new_name_node = original_base.value.with_changes(value=additional_base_name)
|
||||
new_base = original_base.with_changes(value=new_name_node)
|
||||
break
|
||||
new_bases.append(new_base)
|
||||
|
||||
original_methods = {
|
||||
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
|
||||
for f in original_node.body.body
|
||||
@@ -942,12 +1015,17 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
|
||||
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
|
||||
# Extract the original docstring
|
||||
updated_docstring = func.body[0].value.value
|
||||
original_docstring = docstring_node[0].body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
# Update the docstring in the original function
|
||||
docstring_node = [
|
||||
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
|
||||
]
|
||||
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
|
||||
docstring_node = [
|
||||
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
|
||||
]
|
||||
else:
|
||||
original_docstring = docstring_node[0].body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
# Update the docstring in the original function
|
||||
docstring_node = [
|
||||
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
|
||||
]
|
||||
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
|
||||
end_meth.append(func)
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
@@ -970,10 +1048,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
|
||||
|
||||
# Use decorators redefined in `modular_xxx.py` if any
|
||||
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
|
||||
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
|
||||
name = class_node.name
|
||||
|
||||
return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
|
||||
return original_node.with_changes(
|
||||
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
|
||||
)
|
||||
|
||||
|
||||
TYPE_TO_FILE_TYPE = {
|
||||
@@ -1014,14 +1092,18 @@ VARIABLES_AT_THE_BEGINNING = (
|
||||
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
|
||||
|
||||
|
||||
def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
|
||||
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`."""
|
||||
def append_new_import_node(
|
||||
node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
|
||||
):
|
||||
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
|
||||
Also modifies `added_names` in-place accordingly."""
|
||||
import_node = node.body[0]
|
||||
names_to_keep = []
|
||||
for name in import_node.names:
|
||||
name_value = name.evaluated_name
|
||||
if name_value not in unused_imports:
|
||||
if name_value not in unused_imports and name_value not in added_names:
|
||||
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
|
||||
added_names.add(name_value)
|
||||
if len(names_to_keep) > 0:
|
||||
new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
|
||||
imports_to_keep.append(new_node)
|
||||
@@ -1036,40 +1118,38 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
|
||||
wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
|
||||
scopes = set(wrapper.resolve(ScopeProvider).values())
|
||||
unused_imports = set()
|
||||
import_ref_count = {}
|
||||
import_ref_count = defaultdict(lambda: 0)
|
||||
for scope in scopes:
|
||||
for assignment in scope.assignments:
|
||||
node = assignment.node
|
||||
if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
|
||||
ref_count = len(assignment.references)
|
||||
name = assignment.name
|
||||
# Similar imports may be redefined, and only used between their 1st and 2nd definition
|
||||
# so if we already have a ref count > 0, the imports is actually used
|
||||
if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
|
||||
unused_imports.add(name)
|
||||
import_ref_count[name] = ref_count
|
||||
import_ref_count[name] = max(ref_count, import_ref_count[name])
|
||||
# Similar imports may be redefined, and only used between their 1st and 2nd definition so if we already have
|
||||
# a ref count > 0 at any point, the imports is actually used
|
||||
unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}
|
||||
|
||||
imports_to_keep = []
|
||||
# We need to keep track of which names were already imported, because some import may be duplicated from multiple sources
|
||||
# or be both protected and unprotected due to inconsistency between models
|
||||
added_names = set()
|
||||
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
|
||||
for node in all_imports:
|
||||
if m.matches(node, m.If()): # handle safe imports
|
||||
new_statements = []
|
||||
for stmt_node in node.body.body:
|
||||
append_new_import_node(stmt_node, unused_imports, new_statements)
|
||||
append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
|
||||
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
|
||||
if len(new_statements) > 0:
|
||||
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
|
||||
imports_to_keep.append(new_node)
|
||||
existing_protected_statements.update({str(stmt) for stmt in new_statements})
|
||||
else:
|
||||
append_new_import_node(node, unused_imports, imports_to_keep)
|
||||
append_new_import_node(node, unused_imports, added_names, imports_to_keep)
|
||||
|
||||
protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
|
||||
usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]
|
||||
# If the same import is both protected and unprotected, only keep the protected one
|
||||
for protected_node in protected_import_nodes:
|
||||
for stmt_node in protected_node.body.body:
|
||||
usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]
|
||||
|
||||
# Protected imports always appear at the end of all imports
|
||||
return usual_import_nodes + protected_import_nodes
|
||||
@@ -1102,12 +1182,10 @@ class ModularFileMapper(ModuleMapper):
|
||||
Calling the method `create_modules()` after visit will create all modules based on this modular file.
|
||||
"""
|
||||
|
||||
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
|
||||
def __init__(self, python_module, new_name):
|
||||
super().__init__(python_module)
|
||||
# fmt: off
|
||||
self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
|
||||
self.given_old_name = given_old_name
|
||||
self.given_new_name = given_new_name
|
||||
|
||||
self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
|
||||
self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
|
||||
@@ -1191,11 +1269,11 @@ class ModularFileMapper(ModuleMapper):
|
||||
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
|
||||
self.visited_modules = {}
|
||||
self.renamers = {}
|
||||
name_prefixes = self.infer_new_model_name()
|
||||
for file, module in self.model_specific_modules.items():
|
||||
file_model_name = file.split(".")[-2]
|
||||
renamer = ReplaceNameTransformer(
|
||||
file_model_name, self.model_name, self.given_old_name, self.given_new_name
|
||||
)
|
||||
new_name = name_prefixes[file]
|
||||
renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
|
||||
renamed_module = module.visit(renamer)
|
||||
self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
|
||||
renamed_module,
|
||||
@@ -1288,6 +1366,87 @@ class ModularFileMapper(ModuleMapper):
|
||||
|
||||
return relative_order
|
||||
|
||||
def infer_new_model_name(self) -> dict:
|
||||
"""Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
|
||||
This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
|
||||
so we have something like:
|
||||
```python
|
||||
class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
|
||||
pass
|
||||
```
|
||||
with the `Text` prefix added to the model name.
|
||||
However, in case of multiple prefix used, we raise a warning and use the most frequent prefix, to avoid parsing
|
||||
the same file multiple times and inconsistencies in the objects added from dependencies.
|
||||
If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
|
||||
raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.
|
||||
"""
|
||||
prefix_model_name_mapping = defaultdict(Counter)
|
||||
cased_default_name = get_cased_name(self.model_name)
|
||||
# Iterate over all new classes to get modeling super classes
|
||||
for class_name, class_node in self.classes.items():
|
||||
modeling_bases = [
|
||||
k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
|
||||
]
|
||||
if len(modeling_bases) > 1:
|
||||
raise ValueError(
|
||||
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
|
||||
)
|
||||
if len(modeling_bases) == 1:
|
||||
filename = self.model_specific_imported_objects[modeling_bases[0]]
|
||||
cased_model_name = cased_default_name # the default name prefix
|
||||
suffix = common_partial_suffix(class_name, modeling_bases[0])
|
||||
if len(suffix) > 0 and suffix[0].isupper():
|
||||
cased_model_name = class_name.replace(suffix, "")
|
||||
prefix_model_name_mapping[filename].update([cased_model_name])
|
||||
|
||||
# Check if we found multiple prefixes for some modeling files
|
||||
final_name_mapping = {}
|
||||
for file, prefixes_counter in prefix_model_name_mapping.items():
|
||||
if len(prefixes_counter) > 1:
|
||||
_, total = prefixes_counter.most_common(1)[0]
|
||||
most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
|
||||
# if the default name is in the pool of equally used prefixes, use it, otherwise last encountered
|
||||
final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
|
||||
else:
|
||||
final_name = list(prefixes_counter)[0]
|
||||
# Check if the prefix can be used without collisions in the names
|
||||
old_cased_model_name = get_cased_name(file.split(".")[-2])
|
||||
old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
|
||||
# Raise adequate warning depending on the situation
|
||||
has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
|
||||
if final_name != cased_default_name and has_prefix_collision:
|
||||
if len(prefixes_counter) > 1:
|
||||
logger.warning(
|
||||
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
|
||||
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
|
||||
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
|
||||
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
|
||||
f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
|
||||
"already present in the source file and will likely cause consistency issues. For this reason we fallback "
|
||||
f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
|
||||
f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
|
||||
)
|
||||
final_name = cased_default_name
|
||||
elif len(prefixes_counter) > 1:
|
||||
logger.warning(
|
||||
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
|
||||
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
|
||||
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
|
||||
"in all the modular (best)."
|
||||
)
|
||||
final_name_mapping[file] = get_lowercase_name(final_name)
|
||||
|
||||
# Check we are not missing imported files
|
||||
for file in self.model_specific_modules.keys():
|
||||
if file not in final_name_mapping.keys():
|
||||
final_name_mapping[file] = self.model_name
|
||||
|
||||
return final_name_mapping
|
||||
|
||||
|
||||
def check_dependencies_and_create_import_node(
|
||||
file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
|
||||
@@ -1338,11 +1497,11 @@ def get_class_node_and_dependencies(
|
||||
class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
|
||||
the modular that we nay need.
|
||||
"""
|
||||
bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects]
|
||||
if len(bases) > 1:
|
||||
raise ValueError(
|
||||
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}."
|
||||
)
|
||||
# An exception was already raised if this has len > 1
|
||||
model_specific_bases = [
|
||||
k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
|
||||
]
|
||||
super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None
|
||||
|
||||
file_type = find_file_type(class_name)
|
||||
file_to_update = files[file_type]
|
||||
@@ -1352,19 +1511,17 @@ def get_class_node_and_dependencies(
|
||||
imported_objects = modular_mapper.imported_objects_per_file[file_type]
|
||||
|
||||
# We need to replace the class node with the transformers (modeling file) super class node
|
||||
if len(bases) == 1:
|
||||
super_class = bases[0]
|
||||
if super_class is not None:
|
||||
super_file_name = modular_mapper.model_specific_imported_objects[super_class]
|
||||
|
||||
# Get the mapper corresponding to the inherited class
|
||||
mapper = modular_mapper.visited_modules[super_file_name]
|
||||
# Rename the super class according to the exact same rule we used when renaming the whole module
|
||||
renamer = modular_mapper.renamers[super_file_name]
|
||||
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
|
||||
renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)
|
||||
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)
|
||||
|
||||
# Create the new class node
|
||||
updated_node = replace_class_node(mapper, node, renamed_super_class)
|
||||
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
|
||||
|
||||
# Grab all immediate dependencies of the new node
|
||||
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
|
||||
@@ -1468,7 +1625,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
|
||||
return files
|
||||
|
||||
|
||||
def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
|
||||
def convert_modular_file(modular_file):
|
||||
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
|
||||
output = {}
|
||||
if pattern is not None:
|
||||
@@ -1478,8 +1635,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
|
||||
code = file.read()
|
||||
module = cst.parse_module(code)
|
||||
wrapper = MetadataWrapper(module)
|
||||
if cst_transformers is None:
|
||||
cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name)
|
||||
cst_transformers = ModularFileMapper(module, model_name)
|
||||
wrapper.visit(cst_transformers)
|
||||
for file, module in create_modules(cst_transformers).items():
|
||||
if module != {}:
|
||||
@@ -1522,20 +1678,10 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
|
||||
default=["src/transformers/models/gemma/modular_gemma.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_model_name",
|
||||
required=False,
|
||||
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--new_model_name",
|
||||
required=False,
|
||||
help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.files_to_parse == ["all"]:
|
||||
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||
@@ -1544,5 +1690,5 @@ if __name__ == "__main__":
|
||||
for file_name in find_priority_list(args.files_to_parse):
|
||||
print(f"Converting {file_name} to a single model single file format")
|
||||
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
|
||||
converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
|
||||
converted_files = convert_modular_file(file_name)
|
||||
converter = save_modeling_file(file_name, converted_files)
|
||||
|
||||
Reference in New Issue
Block a user