Refactor StarCoder2 using modular (#34015)
* Create modular_starcoder2.py * Update modular_starcoder2.py * update * finalize modular * revert # no-unravel * Add support * style * Update modular_model_converter.py * update docstring
This commit is contained in:
@@ -145,45 +145,69 @@ def is_call_to_super(node, func_name):
|
||||
)
|
||||
|
||||
|
||||
def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
|
||||
"""Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
|
||||
successive value of an Attribute are not Name nodes, return `None`."""
|
||||
if m.matches(node, m.Name()):
|
||||
return node.value
|
||||
elif m.matches(node, m.Attribute()):
|
||||
if not m.matches(node.attr, m.Name()):
|
||||
return None
|
||||
name = node.attr.value
|
||||
new_node = node.value
|
||||
while m.matches(new_node, m.Attribute()):
|
||||
if not m.matches(new_node.attr, m.Name()):
|
||||
return None
|
||||
name = new_node.attr.value + "." + name
|
||||
new_node = new_node.value
|
||||
if not m.matches(new_node, m.Name()):
|
||||
return None
|
||||
return new_node.value + "." + name
|
||||
return None
|
||||
|
||||
|
||||
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
|
||||
class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
def __init__(self, all_bases: Set[str]):
|
||||
self.all_bases = all_bases
|
||||
|
||||
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
|
||||
# Handle ClassB.call_to_method
|
||||
# Handle ClassB.call_to_method or module.classB.call_to_method
|
||||
if (
|
||||
m.matches(original_node.value, m.Name())
|
||||
and original_node.value.value in self.all_bases
|
||||
m.matches(original_node.value, m.Name() | m.Attribute())
|
||||
and get_full_attribute_name(original_node.value) in self.all_bases
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(
|
||||
value=cst.Call(cst.Name("super")),
|
||||
)
|
||||
# Handle ClassB().call_to_method
|
||||
# Handle ClassB().call_to_method or module.ClassB().call_to_method
|
||||
elif (
|
||||
m.matches(original_node.value, m.Call())
|
||||
and m.matches(original_node.value.func, m.Name())
|
||||
and original_node.value.func.value in self.all_bases
|
||||
and m.matches(original_node.value.func, m.Name() | m.Attribute())
|
||||
and get_full_attribute_name(original_node.value.func) in self.all_bases
|
||||
and m.matches(original_node.attr, m.Name())
|
||||
):
|
||||
# Replace with super().call_to_method
|
||||
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
|
||||
return updated_node.with_changes(value=cst.Call(cst.Name("super")))
|
||||
return updated_node
|
||||
|
||||
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
|
||||
if m.matches(original_node.func, m.Attribute()) and (
|
||||
# Match ClassB().func_a(...)
|
||||
# Match ClassB().func_a(...) or module
|
||||
(
|
||||
m.matches(original_node.func.value, m.Call())
|
||||
and m.matches(original_node.func.value.func, m.Name())
|
||||
and original_node.func.value.func.value in self.all_bases
|
||||
and m.matches(original_node.func.value.func, m.Name() | m.Attribute())
|
||||
and get_full_attribute_name(original_node.func.value.func) in self.all_bases
|
||||
)
|
||||
or
|
||||
# Match ClassB.func_a(...)
|
||||
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
|
||||
(
|
||||
m.matches(original_node.func.value, m.Name() | m.Attribute())
|
||||
and get_full_attribute_name(original_node.func.value) in self.all_bases
|
||||
)
|
||||
):
|
||||
# Check if the first argument is 'self', and remove it
|
||||
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
|
||||
@@ -860,7 +884,9 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
|
||||
| self.post_init()
|
||||
| ```
|
||||
"""
|
||||
all_bases = [k.value.value for k in class_node.bases]
|
||||
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
|
||||
if any(base is None for base in all_bases):
|
||||
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
|
||||
|
||||
original_node = mapper.classes[renamed_super_class]
|
||||
original_methods = {
|
||||
@@ -1496,7 +1522,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["src/transformers/models/gemma2/modular_gemma2.py"],
|
||||
default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user