Add support for inheritance from class with different suffix in modular (#34077)
* add support for different suffix in modular * add dummy example, pull new changes for modular * nide lines order change
This commit is contained in:
@@ -204,7 +204,15 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
- LLaMa -> MyNewModel abd MyNewModel -> Llama
|
||||
"""
|
||||
|
||||
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
|
||||
def __init__(
|
||||
self,
|
||||
old_name,
|
||||
new_name,
|
||||
given_old_name=None,
|
||||
given_new_name=None,
|
||||
old_class_name: str = None,
|
||||
new_class_name: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.old_name = old_name
|
||||
self.new_name = new_name
|
||||
@@ -220,6 +228,18 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
}
|
||||
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
|
||||
self.patterns[given_old_name] = given_new_name
|
||||
if self.old_name in CONFIG_MAPPING_NAMES:
|
||||
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
|
||||
if self.default_old_name.isupper():
|
||||
self.default_old_name = self.default_old_name.capitalize()
|
||||
if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns:
|
||||
# In last recourse, when the suffix of the new class is not the same as the old class,
|
||||
# and if the old and new classes start with the default name, we keep the default class name
|
||||
# and replace the old suffix with the new one.
|
||||
# Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`
|
||||
# where a model extends another model, but is used for a different task.
|
||||
if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name):
|
||||
self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :]
|
||||
|
||||
def preserve_case_replace(self, text):
|
||||
# Create a regex pattern to match all variations
|
||||
@@ -235,7 +255,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
|
||||
def convert_to_camelcase(self, text):
|
||||
# Regex pattern to match consecutive uppercase letters and lowercase the first set
|
||||
result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
|
||||
result = re.sub(
|
||||
rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1
|
||||
)
|
||||
return result
|
||||
|
||||
@m.leave(m.Name() | m.SimpleString() | m.Comment())
|
||||
@@ -249,9 +271,24 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
|
||||
|
||||
|
||||
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
|
||||
def find_classes_in_file(
|
||||
module: cst.Module,
|
||||
old_id="llama",
|
||||
new_id="gemma",
|
||||
given_old_name=None,
|
||||
given_new_name=None,
|
||||
old_class_name=None,
|
||||
new_class_name=None,
|
||||
):
|
||||
"""Helper function to rename and then parse a source file using the ClassFinder"""
|
||||
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
|
||||
transformer = ReplaceNameTransformer(
|
||||
old_id,
|
||||
new_id,
|
||||
given_old_name=given_old_name,
|
||||
given_new_name=given_new_name,
|
||||
old_class_name=old_class_name,
|
||||
new_class_name=new_class_name,
|
||||
)
|
||||
new_module = module.visit(transformer)
|
||||
|
||||
wrapper = MetadataWrapper(new_module)
|
||||
@@ -868,7 +905,7 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if list_dependencies == []:
|
||||
if len(list_dependencies) == 0:
|
||||
# so, maybe standard renaming did not work (the class name is different)
|
||||
# we try with another renaming pattern
|
||||
potential_given_name = get_new_part(class_name, super_class)
|
||||
@@ -884,6 +921,30 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if len(list_dependencies) == 0:
|
||||
# last recourse, if the suffix of the new class is different from the one of the super class
|
||||
# e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection
|
||||
# we try with another renaming pattern
|
||||
class_finder = find_classes_in_file(
|
||||
self.transformers_imports[super_file_name],
|
||||
model_name,
|
||||
self.model_name,
|
||||
self.given_old_name,
|
||||
self.given_new_name,
|
||||
super_class,
|
||||
class_name,
|
||||
)
|
||||
visited_module[super_file_name] = class_finder
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if len(list_dependencies) == 0:
|
||||
raise ValueError(
|
||||
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
|
||||
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
|
||||
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
|
||||
)
|
||||
|
||||
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
|
||||
start_insert_idx = self.global_scope_index
|
||||
@@ -917,12 +978,6 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
|
||||
if len(list_dependencies) > 0:
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
|
||||
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
|
||||
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
|
||||
)
|
||||
|
||||
# Now, if a class was defined without parents, we look for the name
|
||||
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
|
||||
|
||||
Reference in New Issue
Block a user