Add support for __all__ and potentilly deleting functions (#33859)

* Add support for __all__ and potentailly deleting functions

* updates

* update

* nits

* remove dummies

* fix warning

* fixup

* style

* update

* fixup

* skip copied from when # skip

* remove log

* bring dummies back

* fixup

* remove copied from

* fixup

* remove warnings from `make fix-copies`

* fix doc issues

* nits

* Better error message !

* add support for more flexible naming!

* style

* breaking style?

* fix super() renaming issues

* del not needed when you don't call super().__init__()

* style

* no more fmt on :)

* properly remove `self`

* fixup

* fix

* doc nits

* add some doc 🫡
This commit is contained in:
Arthur
2024-10-08 10:19:17 +02:00
committed by GitHub
parent bead0fa8dc
commit a3add29097
15 changed files with 477 additions and 149 deletions

View File

@@ -16,7 +16,8 @@ import argparse
import glob
import importlib
import re
from typing import Dict
from collections import defaultdict
from typing import Dict, List, Set
import libcst as cst
from check_copies import run_ruff
@@ -113,7 +114,11 @@ class ClassFinder(CSTVisitor):
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
):
self.assignments[node.body[0].targets[0].target.value] = node
if hasattr(node.body[0].targets[0].target, "value"):
self.assignments[node.body[0].targets[0].target.value] = node
else:
for idx, target in enumerate(list(node.body[0].targets[0].target.elements)):
self.assignments[target.value.value] = node.body[0].value.elements[idx].value
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
self.imports[node.body[0].names] = node
@@ -217,11 +222,21 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return compiled_regex.sub(replace, text)
def convert_to_camelcase(self, text):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
return result
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=update)
def leave_ClassDef(self, original_node, updated_node):
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
"""Helper function to rename and then parse a source file using the ClassFinder"""
@@ -251,6 +266,63 @@ def SUPER_CALL_NODE(func_name):
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
def is_call_to_super(node, func_name):
return m.matches(
node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
)
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
def __init__(self, all_bases: Set[str]):
self.all_bases = all_bases
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
isinstance(original_node.value, cst.Name)
and original_node.value.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
elif (
isinstance(original_node.value, cst.Call)
and isinstance(original_node.value.func, cst.Name)
and original_node.value.func.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if isinstance(original_node.func, cst.Attribute) and (
# Match ClassB().func_a(...)
(
isinstance(original_node.func.value, cst.Call)
and isinstance(original_node.func.value.func, cst.Name)
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
# Create the new argument list without 'self'
new_args = updated_node.args[1:]
else:
new_args = updated_node.args
return updated_node.with_changes(args=new_args)
return updated_node
def get_docstring_indent(docstring):
# Match the first line after the opening triple quotes
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
@@ -263,7 +335,7 @@ def get_docstring_indent(docstring):
def merge_docstrings(original_docstring, updated_docstring):
# indent_level = get_docstring_indent(updated_docstring)
original_level = get_docstring_indent(original_docstring)
if " Args:\n " not in updated_docstring:
if not re.findall(r"\n\s*Args:\n", updated_docstring):
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1:
@@ -292,13 +364,15 @@ def merge_docstrings(original_docstring, updated_docstring):
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments
self.class_name = class_name
self.all_bases = all_bases or []
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))
def update_body(self, existing_body, new_statements):
"""
@@ -356,18 +430,14 @@ class SuperTransformer(cst.CSTTransformer):
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
new_body = []
has_super_call = False
for idx, expr in enumerate(node.body):
if m.matches(
expr,
m.SimpleStatementLine(
body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
),
):
if idx != 0 and func_name == "__init__":
raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
for expr in node.body:
if is_call_to_super(expr, func_name):
has_super_call = True
elif m.matches(expr, DOCSTRING_NODE):
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
else:
expr = expr.visit(self.transformer)
if m.matches(expr, DOCSTRING_NODE):
self.has_docstring = True
if parent_has_docstring: # actually here we ought to de-duplicate?
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
@@ -406,15 +476,17 @@ class SuperTransformer(cst.CSTTransformer):
return updated_node
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
def replace_call_to_super(
class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
):
"""
Given the `class_name`, the `updated_node`'s call to super are unpacked.
| ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self):
Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2
| super().__init__() | | super().__init__(config)
Going from: | super().__init__() | to: | super().__init__(config)
| self.dropout = 0.2 | | self.dropout = 0.2
| ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
@@ -453,7 +525,14 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
func = func.with_changes(body=updated_methods[name].body, params=new_params)
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
class_finder.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(body=updated_methods[name].body, params=new_params)
else:
continue
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
@@ -492,7 +571,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, class_name)
SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
)
new_replacement_body = new_replacement_class.body[0].body # get the indented block
@@ -508,6 +587,31 @@ TYPE_TO_FILE_TYPE = {
}
def get_new_part(class_name, base_class):
"""
When `MyClassNameAttention` inherits from `MistralAttention`, we need
to process the name to properly find dependencies.
Here we take what is the same (Attention) and what is different
when finding the dependencies.
"""
common_suffix_len = 0
for i in range(1, min(len(class_name), len(base_class)) + 1):
if class_name[-i] == base_class[-i]:
common_suffix_len += 1
else:
break
if common_suffix_len > 0:
new_part = class_name[:-common_suffix_len]
else:
new_part = class_name
# Convert the remaining new part to snake_case
snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()
return snake_case
class ModularConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
@@ -538,6 +642,7 @@ class ModularConverterTransformer(CSTTransformer):
}
self.match_patterns = "|".join(self.files.keys())
self.all_definitions = {}
self.class_to_file_type = {}
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
@@ -630,13 +735,33 @@ class ModularConverterTransformer(CSTTransformer):
self.given_new_name,
)
visited_module[super_file_name] = class_finder
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
else: # we are re-using the previously parsed data
class_finder = visited_module[super_file_name]
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
if list_dependencies == []:
# so, maybe standard renaming did not work (the class name is different)
# we try with another renaming pattern
potential_given_name = get_new_part(class_name, super_class)
del visited_module[super_file_name]
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name],
model_name,
potential_given_name,
self.model_name,
potential_given_name,
)
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index
@@ -668,10 +793,12 @@ class ModularConverterTransformer(CSTTransformer):
self.inserted_deps.append(dependency)
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
else:
raise ValueError(
f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
)
# Now, if a class was defined without parents, we look for the name
@@ -679,8 +806,10 @@ class ModularConverterTransformer(CSTTransformer):
match = re.search(rf"({match_pattern})$", class_name)
if match:
key = TYPE_TO_FILE_TYPE[match.group(1)]
self.class_to_file_type[class_name] = key
self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
else:
self.class_to_file_type[class_name] = "modeling"
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
@@ -690,14 +819,37 @@ class ModularConverterTransformer(CSTTransformer):
self.all_definitions[node.name.value] = node
return node
def visit_Assign(self, node: cst.Assign) -> None:
# Check if the assignment target is '__all__'
if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__":
if isinstance(node.value, cst.List):
# Extract the elements from the list
all_all_to_add = defaultdict(list)
for elt in node.value.elements:
if isinstance(elt.value, cst.SimpleString):
# Remove quotes and add the string to the elements list
class_name = elt.value.value
file = self.class_to_file_type[
elt.value.evaluated_value
] # evaluated value give the content of the string
all_all_to_add[file] += [class_name]
for f_type, new_alls in all_all_to_add.items():
updated_node = node.with_changes(
value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls])
)
self.files[f_type][class_name] = {
"insert_idx": self.global_scope_index + 100,
"node": updated_node,
}
def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
full_statement = self.python_module.code_for_node(original_node.test)
if re.search(r"[\s\S]*is_.*available", full_statement):
self.all_safe_imports.append(node)
elif full_statement not in self.new_body:
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
elif full_statement not in self.all_imports:
logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
return node
def leave_Module(self, original_node: cst.Assign, node):
@@ -764,7 +916,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["examples/modular-transformers/modular_dummy.py"],
default=["src/transformers/models/gemma/modular_gemma.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)

View File

@@ -373,7 +373,6 @@ src/transformers/data/processors/squad.py
src/transformers/data/processors/utils.py
src/transformers/data/processors/xnli.py
src/transformers/debug_utils.py
src/transformers/deepspeed.py
src/transformers/dependency_versions_check.py
src/transformers/dependency_versions_table.py
src/transformers/dynamic_module_utils.py