Refactor StarCoder2 using modular (#34015)

* Create modular_starcoder2.py

* Update modular_starcoder2.py

* update

* finalize modular

* revert # no-unravel

* Add support

* style

* Update modular_model_converter.py

* update docstring
This commit is contained in:
Cyril Vallez
2024-11-21 14:52:39 +01:00
committed by GitHub
parent 18871599c9
commit 4e90b99ed9
3 changed files with 643 additions and 66 deletions

View File

@@ -145,45 +145,69 @@ def is_call_to_super(node, func_name):
)
def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
"""Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
successive value of an Attribute are not Name nodes, return `None`."""
if m.matches(node, m.Name()):
return node.value
elif m.matches(node, m.Attribute()):
if not m.matches(node.attr, m.Name()):
return None
name = node.attr.value
new_node = node.value
while m.matches(new_node, m.Attribute()):
if not m.matches(new_node.attr, m.Name()):
return None
name = new_node.attr.value + "." + name
new_node = new_node.value
if not m.matches(new_node, m.Name()):
return None
return new_node.value + "." + name
return None
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
def __init__(self, all_bases: Set[str]):
self.all_bases = all_bases
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
# Handle ClassB.call_to_method or module.classB.call_to_method
if (
m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
m.matches(original_node.value, m.Name() | m.Attribute())
and get_full_attribute_name(original_node.value) in self.all_bases
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
# Handle ClassB().call_to_method or module.ClassB().call_to_method
elif (
m.matches(original_node.value, m.Call())
and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
and m.matches(original_node.value.func, m.Name() | m.Attribute())
and get_full_attribute_name(original_node.value.func) in self.all_bases
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node.with_changes(value=cst.Call(cst.Name("super")))
return updated_node
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
# Match ClassB().func_a(...) or module
(
m.matches(original_node.func.value, m.Call())
and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
and m.matches(original_node.func.value.func, m.Name() | m.Attribute())
and get_full_attribute_name(original_node.func.value.func) in self.all_bases
)
or
# Match ClassB.func_a(...)
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
(
m.matches(original_node.func.value, m.Name() | m.Attribute())
and get_full_attribute_name(original_node.func.value) in self.all_bases
)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
@@ -860,7 +884,9 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
| self.post_init()
| ```
"""
all_bases = [k.value.value for k in class_node.bases]
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
if any(base is None for base in all_bases):
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
original_node = mapper.classes[renamed_super_class]
original_methods = {
@@ -1496,7 +1522,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma2/modular_gemma2.py"],
default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)