Add support for __all__ and potentially deleting functions (#33859)
* Add support for __all__ and potentially deleting functions
* updates
* update
* nits
* remove dummies
* fix warning
* fixup
* style
* update
* fixup
* skip copied from when # skip
* remove log
* bring dummies back
* fixup
* remove copied from
* fixup
* remove warnings from `make fix-copies`
* fix doc issues
* nits
* Better error message !
* add support for more flexible naming!
* style
* breaking style?
* fix super() renaming issues
* del not needed when you don't call super().__init__()
* style
* no more fmt on :)
* properly remove `self`
* fixup
* fix
* doc nits
* add some doc 🫡
This commit is contained in:
@@ -16,7 +16,8 @@ import argparse
|
||||
import glob
|
||||
import importlib
|
||||
import re
|
||||
from typing import Dict
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import libcst as cst
|
||||
from check_copies import run_ruff
|
||||
@@ -113,7 +114,11 @@ class ClassFinder(CSTVisitor):
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
|
||||
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
|
||||
):
|
||||
self.assignments[node.body[0].targets[0].target.value] = node
|
||||
if hasattr(node.body[0].targets[0].target, "value"):
|
||||
self.assignments[node.body[0].targets[0].target.value] = node
|
||||
else:
|
||||
for idx, target in enumerate(list(node.body[0].targets[0].target.elements)):
|
||||
self.assignments[target.value.value] = node.body[0].value.elements[idx].value
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
|
||||
self.imports[node.body[0].names] = node
|
||||
|
||||
@@ -217,11 +222,21 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
|
||||
|
||||
return compiled_regex.sub(replace, text)
|
||||
|
||||
def convert_to_camelcase(self, text):
    """Capitalize a leading run of uppercase letters so it reads as one CamelCase word.

    Only a run at the very start of `text` that is directly followed by an
    uppercase letter and then a lowercase letter is rewritten (for example
    "GPTNeoModel" becomes "GptNeoModel"). Any other input comes back unchanged.
    """
    leading_acronym = re.compile(r"^[A-Z]+(?=[A-Z][a-z])")

    def _capitalize(match):
        # `str.capitalize` keeps the first letter upper and lowers the rest of the run.
        return match.group(0).capitalize()

    return leading_acronym.sub(_capitalize, text, count=1)
|
||||
|
||||
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
    """Rename the text carried by a name, string or comment node.

    A node whose value contains `# Copied from` is removed from its parent.
    Every other node gets its value passed through `preserve_case_replace`.
    """
    node_text = updated_node.value
    if re.findall(r"# Copied from", node_text):
        return cst.RemoveFromParent()
    return updated_node.with_changes(value=self.preserve_case_replace(node_text))
|
||||
|
||||
def leave_ClassDef(self, original_node, updated_node):
    """Replace the class name with the result of `convert_to_camelcase` on it."""
    camel_cased = self.convert_to_camelcase(updated_node.name.value)
    new_name = cst.Name(camel_cased)
    return updated_node.with_changes(name=new_name)
|
||||
|
||||
|
||||
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
|
||||
"""Helper function to rename and then parse a source file using the ClassFinder"""
|
||||
@@ -251,6 +266,63 @@ def SUPER_CALL_NODE(func_name):
|
||||
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
|
||||
|
||||
|
||||
def is_call_to_super(node, func_name):
    """Tell whether `node` is a single-statement line that calls or returns `super().<func_name>(...)`."""
    super_call = SUPER_CALL_NODE(func_name)
    # Accept both `return super().f(...)` and the bare expression `super().f(...)`.
    statement = m.Return(super_call) | m.Expr(super_call)
    return m.matches(node, m.SimpleStatementLine(body=[statement]))
|
||||
|
||||
|
||||
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    """Rewrite explicit calls through a base class into `super()` calls.

    `ClassB.method(self, ...)` and `ClassB().method(self, ...)` both become
    `super().method(...)` whenever `ClassB` is one of `all_bases`.

    Args:
        all_bases: names of the base classes whose explicit references are rewritten.
    """

    def __init__(self, all_bases: Set[str]):
        self.all_bases = all_bases

    def _is_base_reference(self, node) -> bool:
        """Return True when `node` is `ClassB` or `ClassB()` for a `ClassB` listed in `all_bases`."""
        if isinstance(node, cst.Call):
            node = node.func
        return isinstance(node, cst.Name) and node.value in self.all_bases

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        """Turn `ClassB.attr` / `ClassB().attr` into `super().attr`."""
        if self._is_base_reference(original_node.value) and isinstance(original_node.attr, cst.Name):
            # An Attribute node only has `value` and `attr`, so both spellings are handled
            # by swapping `value` for a `super()` call. Passing `func=` here (as the
            # `ClassB()` branch used to) is not a valid field of Attribute.
            return updated_node.with_changes(value=cst.Call(func=cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        """Drop the explicit leading `self` argument from `ClassB.func_a(self, ...)` / `ClassB().func_a(self, ...)`."""
        func = original_node.func
        # Only calls of the form ClassB().func_a(...) or ClassB.func_a(...) are of interest.
        if not (isinstance(func, cst.Attribute) and self._is_base_reference(func.value)):
            return updated_node

        # `super().func_a(...)` binds `self` implicitly, so the explicit one must go.
        if original_node.args and m.matches(original_node.args[0].value, m.Name("self")):
            return updated_node.with_changes(args=updated_node.args[1:])
        return updated_node
|
||||
|
||||
|
||||
def get_docstring_indent(docstring):
|
||||
# Match the first line after the opening triple quotes
|
||||
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
|
||||
@@ -263,7 +335,7 @@ def get_docstring_indent(docstring):
|
||||
def merge_docstrings(original_docstring, updated_docstring):
|
||||
# indent_level = get_docstring_indent(updated_docstring)
|
||||
original_level = get_docstring_indent(original_docstring)
|
||||
if " Args:\n " not in updated_docstring:
|
||||
if not re.findall(r"\n\s*Args:\n", updated_docstring):
|
||||
# Split the docstring at the example section, assuming `"""` is used to define the docstring
|
||||
parts = original_docstring.split("```")
|
||||
if "```" in updated_docstring and len(parts) > 1:
|
||||
@@ -292,13 +364,15 @@ def merge_docstrings(original_docstring, updated_docstring):
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
|
||||
self.python_module = python_module
|
||||
self.original_methods = original_methods
|
||||
self.updated_methods = updated_methods
|
||||
self.all_assign_target = {}
|
||||
self.deleted_targets = {} # child node can delete some arguments
|
||||
self.class_name = class_name
|
||||
self.all_bases = all_bases or []
|
||||
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))
|
||||
|
||||
def update_body(self, existing_body, new_statements):
|
||||
"""
|
||||
@@ -356,18 +430,14 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
|
||||
new_body = []
|
||||
has_super_call = False
|
||||
for idx, expr in enumerate(node.body):
|
||||
if m.matches(
|
||||
expr,
|
||||
m.SimpleStatementLine(
|
||||
body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
|
||||
),
|
||||
):
|
||||
if idx != 0 and func_name == "__init__":
|
||||
raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
|
||||
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
|
||||
|
||||
for expr in node.body:
|
||||
if is_call_to_super(expr, func_name):
|
||||
has_super_call = True
|
||||
elif m.matches(expr, DOCSTRING_NODE):
|
||||
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
|
||||
else:
|
||||
expr = expr.visit(self.transformer)
|
||||
if m.matches(expr, DOCSTRING_NODE):
|
||||
self.has_docstring = True
|
||||
if parent_has_docstring: # actually here we ought to de-duplicate?
|
||||
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
|
||||
@@ -406,15 +476,17 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
return updated_node
|
||||
|
||||
|
||||
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
|
||||
def replace_call_to_super(
|
||||
class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
|
||||
):
|
||||
"""
|
||||
Given the `class_name`, the `updated_node`'s call to super are unpacked.
|
||||
|
||||
| ```python | | ```python
|
||||
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
|
||||
| def __init__(self): | | def __init__(self):
|
||||
Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2
|
||||
| super().__init__() | | super().__init__(config)
|
||||
Going from: | super().__init__() | to: | super().__init__(config)
|
||||
| self.dropout = 0.2 | | self.dropout = 0.2
|
||||
| ``` | | self.padding_idx = config.pad_token_id
|
||||
| self.vocab_size = config.vocab_size
|
||||
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
@@ -453,7 +525,14 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||
new_params = new_params.with_changes(
|
||||
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
|
||||
)
|
||||
func = func.with_changes(body=updated_methods[name].body, params=new_params)
|
||||
if not re.match(
|
||||
r"\ndef .*\(.*\):\n raise.*Error\(.*",
|
||||
class_finder.python_module.code_for_node(updated_methods[name]),
|
||||
):
|
||||
func = func.with_changes(body=updated_methods[name].body, params=new_params)
|
||||
else:
|
||||
continue
|
||||
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = class_finder.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
@@ -492,7 +571,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(
|
||||
SuperTransformer(temp_module, original_methods, updated_methods, class_name)
|
||||
SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
|
||||
)
|
||||
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||
|
||||
@@ -508,6 +587,31 @@ TYPE_TO_FILE_TYPE = {
|
||||
}
|
||||
|
||||
|
||||
def get_new_part(class_name, base_class):
    """
    When `MyClassNameAttention` inherits from `MistralAttention`, we need
    to process the name to properly find dependencies.

    Here we take what is the same (Attention) and what is different
    when finding the dependencies.

    The shared suffix is snapped to a CamelCase word boundary, so that names which
    happen to share trailing letters inside a word are not truncated
    (`GemmaAttention` vs `LlamaAttention` gives `gemma`, not `gem`).

    Args:
        class_name: name of the class defined in the modular file.
        base_class: name of the class it inherits from.

    Returns:
        The part of `class_name` that precedes the shared suffix, in snake_case.
    """
    common_suffix_len = 0
    for i in range(1, min(len(class_name), len(base_class)) + 1):
        if class_name[-i] != base_class[-i]:
            break
        common_suffix_len += 1

    # The raw character match may start in the middle of a word; shrink the suffix
    # until it begins on an uppercase letter (i.e. at the start of a CamelCase word).
    while common_suffix_len > 0 and not class_name[-common_suffix_len].isupper():
        common_suffix_len -= 1

    new_part = class_name[:-common_suffix_len] if common_suffix_len > 0 else class_name

    # Convert the remaining new part to snake_case
    return re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()
|
||||
|
||||
|
||||
class ModularConverterTransformer(CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||
|
||||
@@ -538,6 +642,7 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
}
|
||||
self.match_patterns = "|".join(self.files.keys())
|
||||
self.all_definitions = {}
|
||||
self.class_to_file_type = {}
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||
@@ -630,13 +735,33 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
self.given_new_name,
|
||||
)
|
||||
visited_module[super_file_name] = class_finder
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
else: # we are re-using the previously parsed data
|
||||
class_finder = visited_module[super_file_name]
|
||||
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
if list_dependencies == []:
|
||||
# so, maybe standard renaming did not work (the class name is different)
|
||||
# we try with another renaming pattern
|
||||
potential_given_name = get_new_part(class_name, super_class)
|
||||
del visited_module[super_file_name]
|
||||
class_finder = find_classes_in_file(
|
||||
self.transformers_imports[super_file_name],
|
||||
model_name,
|
||||
potential_given_name,
|
||||
self.model_name,
|
||||
potential_given_name,
|
||||
)
|
||||
list_dependencies = {
|
||||
dep: class_finder.class_start_line.get(dep, 1000)
|
||||
for dep in class_finder.class_dependency_mapping.get(class_name, [])
|
||||
}
|
||||
|
||||
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
|
||||
start_insert_idx = self.global_scope_index
|
||||
@@ -668,10 +793,12 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
self.inserted_deps.append(dependency)
|
||||
|
||||
if len(list_dependencies) > 0:
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
|
||||
updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
|
||||
f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
|
||||
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
|
||||
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
|
||||
)
|
||||
|
||||
# Now, if a class was defined without parents, we look for the name
|
||||
@@ -679,8 +806,10 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
match = re.search(rf"({match_pattern})$", class_name)
|
||||
if match:
|
||||
key = TYPE_TO_FILE_TYPE[match.group(1)]
|
||||
self.class_to_file_type[class_name] = key
|
||||
self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
else:
|
||||
self.class_to_file_type[class_name] = "modeling"
|
||||
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||
return updated_node
|
||||
|
||||
@@ -690,14 +819,37 @@ class ModularConverterTransformer(CSTTransformer):
|
||||
self.all_definitions[node.name.value] = node
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node: cst.Assign) -> None:
    """Split an `__all__ = [...]` assignment into one `__all__` node per target file type.

    Each string element is looked up in `self.class_to_file_type` and grouped by the
    file type found there. For every group, a copy of the assignment holding only that
    group's names is stored in `self.files[file_type]`. An entry missing from
    `self.class_to_file_type` raises `KeyError`.
    """
    first_target = node.targets[0].target
    # Only a plain `__all__ = [...]` list assignment is handled.
    if not (isinstance(first_target, cst.Name) and first_target.value == "__all__"):
        return
    if not isinstance(node.value, cst.List):
        return

    names_per_file = defaultdict(list)
    for element in node.value.elements:
        if not isinstance(element.value, cst.SimpleString):
            continue
        # `value` is the string as written in the source; `evaluated_value` gives the
        # content of the string, which is what `class_to_file_type` is keyed on.
        class_name = element.value.value
        file_type = self.class_to_file_type[element.value.evaluated_value]
        names_per_file[file_type].append(class_name)

    for f_type, written_names in names_per_file.items():
        per_file_all = node.with_changes(
            value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=name)) for name in written_names])
        )
        # NOTE: the key is the last string element seen in the loop above, for every file type.
        self.files[f_type][class_name] = {
            "insert_idx": self.global_scope_index + 100,
            "node": per_file_all,
        }
|
||||
|
||||
def leave_If(self, original_node, node):
|
||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||
if m.matches(parent_node, m.Module()):
|
||||
full_statement = self.python_module.code_for_node(original_node.test)
|
||||
if re.search(r"[\s\S]*is_.*available", full_statement):
|
||||
self.all_safe_imports.append(node)
|
||||
elif full_statement not in self.new_body:
|
||||
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
|
||||
elif full_statement not in self.all_imports:
|
||||
logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
|
||||
return node
|
||||
|
||||
def leave_Module(self, original_node: cst.Assign, node):
|
||||
@@ -764,7 +916,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["examples/modular-transformers/modular_dummy.py"],
|
||||
default=["src/transformers/models/gemma/modular_gemma.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
||||
@@ -373,7 +373,6 @@ src/transformers/data/processors/squad.py
|
||||
src/transformers/data/processors/utils.py
|
||||
src/transformers/data/processors/xnli.py
|
||||
src/transformers/debug_utils.py
|
||||
src/transformers/deepspeed.py
|
||||
src/transformers/dependency_versions_check.py
|
||||
src/transformers/dependency_versions_table.py
|
||||
src/transformers/dynamic_module_utils.py
|
||||
|
||||
Reference in New Issue
Block a user