Add support for inheritance from class with different suffix in modular (#34077)

* add support for different suffix in modular

* add dummy example, pull new changes for modular

* nide lines order change
This commit is contained in:
Yoni Gozlan
2024-10-15 14:55:09 +02:00
committed by GitHub
parent d314ce70bf
commit 65442718c4
3 changed files with 696 additions and 11 deletions

View File

@@ -204,7 +204,15 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
- LLaMa -> MyNewModel abd MyNewModel -> Llama
"""
def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
def __init__(
self,
old_name,
new_name,
given_old_name=None,
given_new_name=None,
old_class_name: str = None,
new_class_name: str = None,
):
super().__init__()
self.old_name = old_name
self.new_name = new_name
@@ -220,6 +228,18 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name
if self.old_name in CONFIG_MAPPING_NAMES:
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
if self.default_old_name.isupper():
self.default_old_name = self.default_old_name.capitalize()
if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns:
# In last recourse, when the suffix of the new class is not the same as the old class,
# and if the old and new classes start with the default name, we keep the default class name
# and replace the old suffix with the new one.
# Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`
# where a model extends another model, but is used for a different task.
if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name):
self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :]
def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
@@ -235,7 +255,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
def convert_to_camelcase(self, text):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
result = re.sub(
rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1
)
return result
@m.leave(m.Name() | m.SimpleString() | m.Comment())
@@ -249,9 +271,24 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
def find_classes_in_file(
module: cst.Module,
old_id="llama",
new_id="gemma",
given_old_name=None,
given_new_name=None,
old_class_name=None,
new_class_name=None,
):
"""Helper function to rename and then parse a source file using the ClassFinder"""
transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
transformer = ReplaceNameTransformer(
old_id,
new_id,
given_old_name=given_old_name,
given_new_name=given_new_name,
old_class_name=old_class_name,
new_class_name=new_class_name,
)
new_module = module.visit(transformer)
wrapper = MetadataWrapper(new_module)
@@ -868,7 +905,7 @@ class ModularConverterTransformer(CSTTransformer):
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
if list_dependencies == []:
if len(list_dependencies) == 0:
# so, maybe standard renaming did not work (the class name is different)
# we try with another renaming pattern
potential_given_name = get_new_part(class_name, super_class)
@@ -884,6 +921,30 @@ class ModularConverterTransformer(CSTTransformer):
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
if len(list_dependencies) == 0:
# last recourse, if the suffix of the new class is different from the one of the super class
# e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection
# we try with another renaming pattern
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name],
model_name,
self.model_name,
self.given_old_name,
self.given_new_name,
super_class,
class_name,
)
visited_module[super_file_name] = class_finder
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
if len(list_dependencies) == 0:
raise ValueError(
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
)
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index
@@ -917,12 +978,6 @@ class ModularConverterTransformer(CSTTransformer):
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
else:
raise ValueError(
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
)
# Now, if a class was defined without parents, we look for the name
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())