Modular fix (#34802)
* Modular fix * style * remove logger warning * Update modular_model_converter.py
This commit is contained in:
@@ -130,6 +130,16 @@ class MyNewModelConfig(PretrainedConfig):
|
|||||||
|
|
||||||
model_type = "my_new_model"
|
model_type = "my_new_model"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
# Default tensor parallel plan for base model `MyNewModelModel`
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"layers.*.self_attn.q_proj": "colwise",
|
||||||
|
"layers.*.self_attn.k_proj": "colwise",
|
||||||
|
"layers.*.self_attn.v_proj": "colwise",
|
||||||
|
"layers.*.self_attn.o_proj": "rowwise",
|
||||||
|
"layers.*.mlp.gate_proj": "colwise",
|
||||||
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -33,6 +33,16 @@ class MyNewModel2Config(PretrainedConfig):
|
|||||||
|
|
||||||
model_type = "my_new_model2"
|
model_type = "my_new_model2"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
# Default tensor parallel plan for base model `MyNewModel2Model`
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"layers.*.self_attn.q_proj": "colwise",
|
||||||
|
"layers.*.self_attn.k_proj": "colwise",
|
||||||
|
"layers.*.self_attn.v_proj": "colwise",
|
||||||
|
"layers.*.self_attn.o_proj": "rowwise",
|
||||||
|
"layers.*.mlp.gate_proj": "colwise",
|
||||||
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import math
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -150,25 +149,7 @@ class DummyMLP(nn.Module):
|
|||||||
self.act_fn = ACT2FN[config.hidden_act]
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.config.pretraining_tp > 1:
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
slice = self.intermediate_size // self.config.pretraining_tp
|
|
||||||
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
|
||||||
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
|
||||||
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
|
||||||
|
|
||||||
gate_proj = torch.cat(
|
|
||||||
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
|
||||||
)
|
|
||||||
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
|
||||||
|
|
||||||
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
|
||||||
down_proj = [
|
|
||||||
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
|
||||||
]
|
|
||||||
down_proj = sum(down_proj)
|
|
||||||
else:
|
|
||||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
@@ -264,31 +245,14 @@ class DummyAttention(nn.Module):
|
|||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
query_states = self.q_proj(hidden_states)
|
||||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
key_states = self.k_proj(hidden_states)
|
||||||
query_slices = self.q_proj.weight.split(
|
value_states = self.v_proj(hidden_states)
|
||||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -330,12 +294,7 @@ class DummyAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
attn_output = self.o_proj(attn_output)
|
||||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
|
||||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
@@ -508,9 +467,10 @@ class DummySdpaAttention(DummyAttention):
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -794,7 +754,10 @@ class DummyModel(DummyPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.rotary_emb = DummyRotaryEmbedding(config=config)
|
self.rotary_emb = DummyRotaryEmbedding(config=config)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
if getattr(config, "pretraining_tp", 1) != 1:
|
||||||
|
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
@@ -874,7 +837,7 @@ class DummyModel(DummyPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
|||||||
@@ -667,7 +667,10 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
|
|||||||
[MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
[MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
)
|
)
|
||||||
self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
if getattr(config, "pretraining_tp", 1) != 1:
|
||||||
|
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
@@ -752,7 +755,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = None
|
next_decoder_cache = None
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import math
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -150,25 +149,7 @@ class SuperMLP(nn.Module):
|
|||||||
self.act_fn = ACT2FN[config.hidden_act]
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.config.pretraining_tp > 1:
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
slice = self.intermediate_size // self.config.pretraining_tp
|
|
||||||
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
|
||||||
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
|
||||||
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
|
||||||
|
|
||||||
gate_proj = torch.cat(
|
|
||||||
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
|
||||||
)
|
|
||||||
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
|
||||||
|
|
||||||
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
|
||||||
down_proj = [
|
|
||||||
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
|
||||||
]
|
|
||||||
down_proj = sum(down_proj)
|
|
||||||
else:
|
|
||||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
@@ -264,31 +245,14 @@ class SuperAttention(nn.Module):
|
|||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
query_states = self.q_proj(hidden_states)
|
||||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
key_states = self.k_proj(hidden_states)
|
||||||
query_slices = self.q_proj.weight.split(
|
value_states = self.v_proj(hidden_states)
|
||||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -330,12 +294,7 @@ class SuperAttention(nn.Module):
|
|||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
attn_output = self.o_proj(attn_output)
|
||||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
|
||||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
@@ -508,9 +467,10 @@ class SuperSdpaAttention(SuperAttention):
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -794,7 +754,10 @@ class SuperModel(SuperPreTrainedModel):
|
|||||||
)
|
)
|
||||||
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.rotary_emb = SuperRotaryEmbedding(config=config)
|
self.rotary_emb = SuperRotaryEmbedding(config=config)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
if getattr(config, "pretraining_tp", 1) != 1:
|
||||||
|
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|||||||
@@ -266,7 +266,6 @@ class SuperTransformer(cst.CSTTransformer):
|
|||||||
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
|
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
|
||||||
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
|
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
|
||||||
if target in self.deleted_targets:
|
if target in self.deleted_targets:
|
||||||
logger.warning(f"Deleted the assign for {target}")
|
|
||||||
continue
|
continue
|
||||||
if target in self.all_assign_target:
|
if target in self.all_assign_target:
|
||||||
stmt = self.all_assign_target[target]
|
stmt = self.all_assign_target[target]
|
||||||
@@ -773,6 +772,8 @@ class ModelFileMapper(ModuleMapper):
|
|||||||
self.object_dependency_mapping.update(
|
self.object_dependency_mapping.update(
|
||||||
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
|
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
|
||||||
)
|
)
|
||||||
|
# Add them to global nodes
|
||||||
|
self.global_nodes.update(self.functions)
|
||||||
|
|
||||||
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
|
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
|
||||||
"""Update the global nodes with the assignment from the modular file.
|
"""Update the global nodes with the assignment from the modular file.
|
||||||
@@ -786,6 +787,8 @@ class ModelFileMapper(ModuleMapper):
|
|||||||
self.assignments[assignment] = node
|
self.assignments[assignment] = node
|
||||||
if assignment in object_mapping:
|
if assignment in object_mapping:
|
||||||
self.object_dependency_mapping[assignment] = object_mapping[assignment]
|
self.object_dependency_mapping[assignment] = object_mapping[assignment]
|
||||||
|
# Add them to global nodes
|
||||||
|
self.global_nodes.update(self.assignments)
|
||||||
|
|
||||||
def _merge_classes(self, classes: dict[str, cst.CSTNode]):
|
def _merge_classes(self, classes: dict[str, cst.CSTNode]):
|
||||||
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
|
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
|
||||||
@@ -813,10 +816,7 @@ class ModelFileMapper(ModuleMapper):
|
|||||||
self._merge_classes(classes)
|
self._merge_classes(classes)
|
||||||
self.modular_file_start_lines = start_lines
|
self.modular_file_start_lines = start_lines
|
||||||
|
|
||||||
# Correctly re-set the global nodes at this point
|
# Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports
|
||||||
self.global_nodes.update(self.functions)
|
|
||||||
self.global_nodes.update(self.assignments)
|
|
||||||
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
|
|
||||||
self._restrict_dependencies_to_known_entities()
|
self._restrict_dependencies_to_known_entities()
|
||||||
# Create the global mapping of recursive dependencies for functions and assignments
|
# Create the global mapping of recursive dependencies for functions and assignments
|
||||||
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
|
||||||
@@ -1024,14 +1024,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
|
|||||||
import_ref_count[name] = ref_count
|
import_ref_count[name] = ref_count
|
||||||
|
|
||||||
imports_to_keep = []
|
imports_to_keep = []
|
||||||
|
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
|
||||||
for node in all_imports:
|
for node in all_imports:
|
||||||
if m.matches(node, m.If()): # handle safe imports
|
if m.matches(node, m.If()): # handle safe imports
|
||||||
new_statements = []
|
new_statements = []
|
||||||
for stmt_node in node.body.body:
|
for stmt_node in node.body.body:
|
||||||
append_new_import_node(stmt_node, unused_imports, new_statements)
|
append_new_import_node(stmt_node, unused_imports, new_statements)
|
||||||
|
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
|
||||||
if len(new_statements) > 0:
|
if len(new_statements) > 0:
|
||||||
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
|
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
|
||||||
imports_to_keep.append(new_node)
|
imports_to_keep.append(new_node)
|
||||||
|
existing_protected_statements.update({str(stmt) for stmt in new_statements})
|
||||||
else:
|
else:
|
||||||
append_new_import_node(node, unused_imports, imports_to_keep)
|
append_new_import_node(node, unused_imports, imports_to_keep)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user