[modular] Simplify logic and docstring handling (#39185)
* simplify a lot * Update modular_model_converter.py * finalize * remove outdated functions * apply it * and examples
This commit is contained in:
@@ -249,90 +249,13 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
return updated_node
|
||||
|
||||
|
||||
def get_docstring_indent(docstring):
|
||||
# Match the first line after the opening triple quotes
|
||||
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
|
||||
if match:
|
||||
# Return the indentation spaces captured
|
||||
return len(match.group(1))
|
||||
return 0
|
||||
|
||||
|
||||
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
|
||||
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
|
||||
be merged with the existing old one.
|
||||
"""
|
||||
# libcst returns the docstrinbgs with literal `r"""` quotes in front
|
||||
new_docstring = new_docstring.split('"""', 1)[1]
|
||||
# The docstring contains Args definition, so it is self-contained
|
||||
if re.search(r"\n\s*Args:\n", new_docstring):
|
||||
return True
|
||||
elif re.search(r"\n\s*Args:\n", original_docstring):
|
||||
return False
|
||||
# Check if the docstring contains args docstring (meaning it is self contained):
|
||||
param_pattern = re.compile(
|
||||
# |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
|
||||
rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
match_object = param_pattern.search(new_docstring)
|
||||
if match_object is not None:
|
||||
return True
|
||||
# If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
|
||||
# (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
|
||||
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
|
||||
if match_object is not None:
|
||||
full_indent = match_object.group(1)
|
||||
striped_doc = new_docstring.strip("\n")
|
||||
if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def merge_docstrings(original_docstring, updated_docstring):
|
||||
original_level = get_docstring_indent(original_docstring)
|
||||
if not is_full_docstring(original_docstring, updated_docstring, original_level):
|
||||
# Split the docstring at the example section, assuming `"""` is used to define the docstring
|
||||
parts = original_docstring.split("```")
|
||||
if "```" in updated_docstring and len(parts) > 1:
|
||||
updated_docstring = updated_docstring.lstrip('r"')
|
||||
new_parts = updated_docstring.split("```")
|
||||
if len(new_parts) != 3:
|
||||
raise ValueError("There should only be one example, and it should have opening and closing '```'")
|
||||
parts[1] = new_parts[1]
|
||||
updated_docstring = "".join(
|
||||
[
|
||||
f"\n{original_level * ' '}```",
|
||||
parts[1],
|
||||
"```",
|
||||
parts[2],
|
||||
]
|
||||
)
|
||||
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
|
||||
new_start_docstring = new_parts[0].rstrip(" \n")
|
||||
docstring_opening += '"""'
|
||||
if new_start_docstring.startswith(original_start_docstring):
|
||||
updated_docstring = new_start_docstring + "\n" + updated_docstring
|
||||
elif original_start_docstring.endswith(new_start_docstring):
|
||||
updated_docstring = original_start_docstring + "\n" + updated_docstring
|
||||
else:
|
||||
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
|
||||
updated_docstring = docstring_opening + updated_docstring
|
||||
elif updated_docstring not in original_docstring:
|
||||
# add tabulation if we are at the lowest level.
|
||||
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
|
||||
updated_docstring = updated_docstring.replace("\n ", "\n ")
|
||||
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
|
||||
return updated_docstring
|
||||
|
||||
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
|
||||
def __init__(self, python_module: cst.Module, original_modeling_methods, modular_methods, all_bases=None):
|
||||
self.python_module = python_module
|
||||
self.original_methods = original_methods
|
||||
self.updated_methods = updated_methods
|
||||
self.original_modeling_methods = original_modeling_methods
|
||||
self.modular_methods = modular_methods
|
||||
self.all_assign_target = {}
|
||||
self.deleted_targets = {} # child node can delete some arguments
|
||||
self.all_bases = all_bases or []
|
||||
@@ -414,53 +337,39 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
break
|
||||
return new_body
|
||||
|
||||
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
|
||||
def replace_super_calls(self, node: cst.BaseSuite, func_name: str) -> cst.BaseSuite:
|
||||
"""Updates the body of the input `node`'s `func_name` function by replacing calls
|
||||
to super().func_name() with the source code of the parent class' `func_name`.
|
||||
It keeps everything that is defined before `super().func_name()`.
|
||||
"""
|
||||
self.has_docstring = False
|
||||
parent_has_docstring = False
|
||||
if func_name in self.original_methods:
|
||||
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
|
||||
new_body = []
|
||||
has_super_call = False
|
||||
modular_node_body = node.body
|
||||
|
||||
for i, expr in enumerate(node.body):
|
||||
for i, expr in enumerate(modular_node_body):
|
||||
if is_call_to_super(expr, func_name):
|
||||
has_super_call = True
|
||||
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
|
||||
original_modeling_method_body = self.original_modeling_methods[func_name].body.body
|
||||
new_body.extend(self.update_body(original_modeling_method_body, modular_node_body[i + 1 :]))
|
||||
new_body = self._fix_init_location(new_body)
|
||||
return node.with_changes(body=new_body)
|
||||
else:
|
||||
expr = expr.visit(self.transformer)
|
||||
if m.matches(expr, DOCSTRING_NODE):
|
||||
self.has_docstring = True
|
||||
if parent_has_docstring: # actually here we ought to de-duplicate?
|
||||
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
|
||||
updated_docstring = expr.body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
|
||||
else:
|
||||
new_node = [expr]
|
||||
new_body.extend(new_node)
|
||||
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
|
||||
if not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])):
|
||||
new_body.append(expr)
|
||||
if not self.has_docstring and parent_has_docstring:
|
||||
new_body = [self.original_methods[func_name].body.body[0]] + new_body
|
||||
|
||||
return node.with_changes(body=new_body)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
if updated_node.name.value in self.updated_methods:
|
||||
name = updated_node.name.value
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
name = updated_node.name.value
|
||||
if name in self.modular_methods:
|
||||
new_body = self.replace_super_calls(updated_node.body, name)
|
||||
return updated_node.with_changes(body=new_body, params=updated_node.params)
|
||||
return updated_node
|
||||
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return:
|
||||
""" "When a return statement is reached, it is replaced with the unrolled super code"""
|
||||
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
|
||||
func_def = self.get_metadata(ParentNodeProvider, original_node)
|
||||
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
|
||||
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_modeling_methods:
|
||||
updated_return_value = updated_node.value.with_changes(
|
||||
args=[
|
||||
cst.Arg(
|
||||
@@ -979,55 +888,52 @@ def common_partial_suffix(str1: str, str2: str) -> str:
|
||||
|
||||
|
||||
def replace_class_node(
|
||||
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
|
||||
):
|
||||
mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
|
||||
) -> cst.ClassDef:
|
||||
"""
|
||||
Replace a class node which inherits from another modeling class. This function works in the following way:
|
||||
- start from the base class node of the inherited class (a cst.Node)
|
||||
- replace all methods of the base node with the methods defined in the child class
|
||||
- append all new methods defined in the child class
|
||||
- start from the methods and class attributes of the original modeling code node, and replace their definition
|
||||
if overriden in the modular
|
||||
- append all new methods and class attributes defined in the child class
|
||||
- all potential method/class docstrings and decorators use the ones found in modular if any, else in original modeling
|
||||
- replace all calls to super() with the unravelled code
|
||||
|
||||
| ```python | | ```python
|
||||
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
|
||||
| def __init__(self): | | def __init__(self):
|
||||
Going from: | super().__init__() | to: | super().__init__(config)
|
||||
| self.dropout = 0.2 | | self.dropout = 0.2
|
||||
| ``` | | self.padding_idx = config.pad_token_id
|
||||
| self.vocab_size = config.vocab_size
|
||||
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
| self.layers = nn.ModuleList(
|
||||
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
| )
|
||||
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
| self.gradient_checkpointing = False
|
||||
| # Initialize weights and apply final processing
|
||||
| self.post_init()
|
||||
| ```
|
||||
"""
|
||||
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
|
||||
if any(base is None for base in all_bases):
|
||||
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
|
||||
Args:
|
||||
mapper (`ModelFileMapper`):
|
||||
The mapper corresponding to the visited file from which the modular class node inherits.
|
||||
modular_class_node (`cst.ClassDef`):
|
||||
The class node as found in the modular file.
|
||||
renamed_super_class (`str`):
|
||||
The name of the class from which `modular_class_node` inherits after automatic renaming.
|
||||
original_super_class (`str`):
|
||||
The name of the class from which `modular_class_node` inherits before automatic renaming.
|
||||
|
||||
original_node = mapper.classes[renamed_super_class]
|
||||
Returns:
|
||||
A new class node corresponding to the modular definition.
|
||||
"""
|
||||
all_bases = [get_full_attribute_name(k.value) for k in modular_class_node.bases]
|
||||
if any(base is None for base in all_bases):
|
||||
raise ValueError(f"Could not parse the name of the bases for {modular_class_node.name.value}")
|
||||
|
||||
original_modeling_node = mapper.classes[renamed_super_class]
|
||||
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
|
||||
new_name = class_node.name
|
||||
new_class_name = modular_class_node.name
|
||||
|
||||
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
|
||||
if new_name.value != renamed_super_class:
|
||||
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
|
||||
if new_class_name.value != renamed_super_class:
|
||||
common_suffix = common_partial_suffix(new_class_name.value, renamed_super_class)
|
||||
# Note that this works even without common prefix, in which case it does not replace anything
|
||||
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
|
||||
temp_module = cst.Module(body=[original_node])
|
||||
original_node = temp_module.visit(
|
||||
old, new = renamed_super_class.replace(common_suffix, ""), new_class_name.value.replace(common_suffix, "")
|
||||
temp_module = cst.Module(body=[original_modeling_node])
|
||||
original_modeling_node = temp_module.visit(
|
||||
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
|
||||
).body[0]
|
||||
|
||||
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
|
||||
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
|
||||
additional_bases = [base for base in all_bases if base != original_super_class]
|
||||
new_bases = []
|
||||
for original_base in original_node.bases:
|
||||
new_class_bases = []
|
||||
for original_base in original_modeling_node.bases:
|
||||
new_base = original_base
|
||||
# we only potentially switch base for Name-based bases, not Attribute
|
||||
if m.matches(original_base.value, m.Name()):
|
||||
@@ -1038,106 +944,125 @@ def replace_class_node(
|
||||
new_name_node = original_base.value.with_changes(value=additional_base_name)
|
||||
new_base = original_base.with_changes(value=new_name_node)
|
||||
break
|
||||
new_bases.append(new_base)
|
||||
new_class_bases.append(new_base)
|
||||
|
||||
original_methods = {
|
||||
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
|
||||
for f in original_node.body.body
|
||||
# Use class decorators redefined in modular file if any
|
||||
new_class_decorators = (
|
||||
modular_class_node.decorators if len(modular_class_node.decorators) > 0 else original_modeling_node.decorators
|
||||
)
|
||||
|
||||
# Compute new class docstring
|
||||
original_modeling_docstring = [
|
||||
node for node in original_modeling_node.body.body if m.matches(node, DOCSTRING_NODE)
|
||||
]
|
||||
modular_docstring = [node for node in modular_class_node.body.body if m.matches(node, DOCSTRING_NODE)]
|
||||
# Use class docstring in modular if any, else original modeling code docstring
|
||||
new_class_docstring = modular_docstring if len(modular_docstring) > 0 else original_modeling_docstring
|
||||
|
||||
# Compute new class attributes
|
||||
original_modeling_class_attributes = {
|
||||
node.body[0].targets[0].target.value: node
|
||||
for node in original_modeling_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
|
||||
}
|
||||
updated_methods = {
|
||||
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
|
||||
original_modeling_class_attributes.update(
|
||||
{
|
||||
node.body[0].target.value: node
|
||||
for node in original_modeling_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
|
||||
}
|
||||
)
|
||||
modular_class_attributes = {
|
||||
node.body[0].targets[0].target.value: node
|
||||
for node in modular_class_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
|
||||
}
|
||||
end_meth = []
|
||||
modular_class_attributes.update(
|
||||
{
|
||||
node.body[0].target.value: node
|
||||
for node in modular_class_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
|
||||
}
|
||||
)
|
||||
# Use all original modeling attributes, and potentially override some with values in the modular
|
||||
new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values())
|
||||
|
||||
assign_targets = {}
|
||||
docstring_node = []
|
||||
# Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
|
||||
for func in original_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
|
||||
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
|
||||
new_params = updated_methods[name].params
|
||||
# Replace the method in the replacement class, preserving decorators
|
||||
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
|
||||
if kwarg_name and kwarg_name.name.value == "super_kwargs":
|
||||
parent_params = {k.name.value: k for k in func.params.params}
|
||||
parent_params.update({k.name.value: k for k in new_params.params[1:]})
|
||||
new_params = new_params.with_changes(
|
||||
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
|
||||
)
|
||||
# Keep decorators in `modular_xxx.py` if any, else original decorators
|
||||
new_decorators = (
|
||||
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
|
||||
)
|
||||
original_modeling_methods = {
|
||||
node.name.value: node for node in original_modeling_node.body.body if m.matches(node, m.FunctionDef())
|
||||
}
|
||||
modular_methods = {
|
||||
node.name.value: node for node in modular_class_node.body.body if m.matches(node, m.FunctionDef())
|
||||
}
|
||||
|
||||
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
|
||||
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
|
||||
new_class_methods = []
|
||||
# Iterate over the methods of the original modeling code, and add them to the list of methods to add
|
||||
for name, node in original_modeling_methods.items():
|
||||
# If the method was redefined in modular, make appropriate changes to the node
|
||||
if name in modular_methods:
|
||||
# Get the corresponding method node in modular
|
||||
modular_node = modular_methods[name]
|
||||
|
||||
if not re.match(
|
||||
r"\ndef .*\(.*\):\n raise.*Error\(.*",
|
||||
mapper.python_module.code_for_node(updated_methods[name]),
|
||||
):
|
||||
func = func.with_changes(
|
||||
body=updated_methods[name].body,
|
||||
params=new_params,
|
||||
decorators=new_decorators,
|
||||
returns=new_return_annotation,
|
||||
)
|
||||
else:
|
||||
# If we match the pattern, we should avoid inheriting the method
|
||||
if re.match(r"\ndef .*\(.*\):\n raise.*Error\(.*", mapper.python_module.code_for_node(modular_node)):
|
||||
continue
|
||||
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, DOCSTRING_NODE):
|
||||
docstring_node = [func]
|
||||
else:
|
||||
end_meth.append(func)
|
||||
# Compute new method docstring
|
||||
modeling_docstring = [node_ for node_ in node.body.body if m.matches(node_, DOCSTRING_NODE)]
|
||||
modular_docstring = [node_ for node_ in modular_node.body.body if m.matches(node_, DOCSTRING_NODE)]
|
||||
# Use method docstring in modular if any, else original modeling code docstring
|
||||
new_body = (
|
||||
modular_node.body.body
|
||||
if len(modular_docstring) > 0
|
||||
else modeling_docstring + list(modular_node.body.body)
|
||||
)
|
||||
new_body = modular_node.body.with_changes(body=new_body)
|
||||
|
||||
# Use arguments as defined in the modular
|
||||
new_params = modular_node.params
|
||||
|
||||
# If using the `**super_kwargs` syntax in modular, merge any existing modular arg with all the original modeling ones
|
||||
kwarg_name = getattr(modular_node.params, "star_kwarg", None)
|
||||
if kwarg_name and kwarg_name.name.value == "super_kwargs":
|
||||
original_modeling_params = {k.name.value: k for k in node.params.params}
|
||||
modular_params = {k.name.value: k for k in new_params.params[1:]}
|
||||
new_param_list = list({**original_modeling_params, **modular_params}.values())
|
||||
new_params = new_params.with_changes(params=new_param_list, star_kwarg=node.params.star_kwarg)
|
||||
|
||||
# Keep decorators in modular if any, else original decorators
|
||||
new_decorators = modular_node.decorators if len(modular_node.decorators) > 0 else node.decorators
|
||||
|
||||
# Keep return annotation in modular if any, else original return annotation
|
||||
new_return_annotation = modular_node.returns if modular_node.returns else node.returns
|
||||
|
||||
# Update the method node
|
||||
node = node.with_changes(
|
||||
body=new_body,
|
||||
params=new_params,
|
||||
decorators=new_decorators,
|
||||
returns=new_return_annotation,
|
||||
)
|
||||
|
||||
new_class_methods.append(node)
|
||||
|
||||
# Port new methods that are defined only in modular-file and append at the end
|
||||
for func in class_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
|
||||
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
|
||||
# Extract the original docstring
|
||||
updated_docstring = func.body[0].value.value
|
||||
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
|
||||
docstring_node = [
|
||||
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
|
||||
]
|
||||
else:
|
||||
original_docstring = docstring_node[0].body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
# Update the docstring in the original function
|
||||
docstring_node = [
|
||||
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
|
||||
]
|
||||
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
|
||||
end_meth.append(func)
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
# TODO we only use single assign might cause issues
|
||||
target = mapper.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
end_meth = docstring_node + list(assign_targets.values()) + end_meth
|
||||
for name, node in modular_methods.items():
|
||||
if name not in original_modeling_methods:
|
||||
new_class_methods.append(node)
|
||||
|
||||
# Replace the calls to `super()` with the unrolled code
|
||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
# Recreate the whole new class body
|
||||
new_class_body = new_class_docstring + new_class_attributes + new_class_methods
|
||||
|
||||
# Replace the calls to `super()` of the redefined modular methods with the unrolled code
|
||||
result_node = original_modeling_node.with_changes(body=cst.IndentedBlock(body=new_class_body))
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(
|
||||
SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
|
||||
SuperTransformer(temp_module, original_modeling_methods, modular_methods, all_bases)
|
||||
)
|
||||
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||
new_class_body = new_replacement_class.body[0].body # get the indented block
|
||||
|
||||
# Use decorators redefined in `modular_xxx.py` if any
|
||||
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
|
||||
|
||||
return original_node.with_changes(
|
||||
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
|
||||
return original_modeling_node.with_changes(
|
||||
body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user