Modular fix (#34802)

* Modular fix

* style

* remove logger warning

* Update modular_model_converter.py
This commit is contained in:
Cyril Vallez
2024-11-19 16:08:57 +01:00
committed by GitHub
parent ce1d328e3b
commit e3a5889ef0
6 changed files with 65 additions and 113 deletions

View File

@@ -130,6 +130,16 @@ class MyNewModelConfig(PretrainedConfig):
model_type = "my_new_model" model_type = "my_new_model"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `MyNewModelModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__( def __init__(
self, self,

View File

@@ -33,6 +33,16 @@ class MyNewModel2Config(PretrainedConfig):
model_type = "my_new_model2" model_type = "my_new_model2"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `MyNewModel2Model`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__( def __init__(
self, self,

View File

@@ -8,7 +8,6 @@ import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
@@ -150,25 +149,7 @@ class DummyMLP(nn.Module):
self.act_fn = ACT2FN[config.hidden_act] self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x): def forward(self, x):
if self.config.pretraining_tp > 1: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj return down_proj
@@ -264,31 +245,14 @@ class DummyAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1: query_states = self.q_proj(hidden_states)
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp key_states = self.k_proj(hidden_states)
query_slices = self.q_proj.weight.split( value_states = self.v_proj(hidden_states)
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = torch.cat(query_states, dim=-1) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
@@ -330,12 +294,7 @@ class DummyAttention(nn.Module):
attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1: attn_output = self.o_proj(attn_output)
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
attn_weights = None attn_weights = None
@@ -508,9 +467,10 @@ class DummySdpaAttention(DummyAttention):
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
@@ -794,7 +754,10 @@ class DummyModel(DummyPreTrainedModel):
) )
self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DummyRotaryEmbedding(config=config) self.rotary_emb = DummyRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -874,7 +837,7 @@ class DummyModel(DummyPreTrainedModel):
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
for decoder_layer in self.layers: for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)

View File

@@ -667,7 +667,10 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
[MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -752,7 +755,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
for decoder_layer in self.layers: for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)

View File

@@ -8,7 +8,6 @@ import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
@@ -150,25 +149,7 @@ class SuperMLP(nn.Module):
self.act_fn = ACT2FN[config.hidden_act] self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x): def forward(self, x):
if self.config.pretraining_tp > 1: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj return down_proj
@@ -264,31 +245,14 @@ class SuperAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1: query_states = self.q_proj(hidden_states)
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp key_states = self.k_proj(hidden_states)
query_slices = self.q_proj.weight.split( value_states = self.v_proj(hidden_states)
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = torch.cat(query_states, dim=-1) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
@@ -330,12 +294,7 @@ class SuperAttention(nn.Module):
attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1: attn_output = self.o_proj(attn_output)
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions: if not output_attentions:
attn_weights = None attn_weights = None
@@ -508,9 +467,10 @@ class SuperSdpaAttention(SuperAttention):
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
@@ -794,7 +754,10 @@ class SuperModel(SuperPreTrainedModel):
) )
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = SuperRotaryEmbedding(config=config) self.rotary_emb = SuperRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()

View File

@@ -266,7 +266,6 @@ class SuperTransformer(cst.CSTTransformer):
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(stmt.body[0].targets[0].target) target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
if target in self.deleted_targets: if target in self.deleted_targets:
logger.warning(f"Deleted the assign for {target}")
continue continue
if target in self.all_assign_target: if target in self.all_assign_target:
stmt = self.all_assign_target[target] stmt = self.all_assign_target[target]
@@ -773,6 +772,8 @@ class ModelFileMapper(ModuleMapper):
self.object_dependency_mapping.update( self.object_dependency_mapping.update(
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
) )
# Add them to global nodes
self.global_nodes.update(self.functions)
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes with the assignment from the modular file. """Update the global nodes with the assignment from the modular file.
@@ -786,6 +787,8 @@ class ModelFileMapper(ModuleMapper):
self.assignments[assignment] = node self.assignments[assignment] = node
if assignment in object_mapping: if assignment in object_mapping:
self.object_dependency_mapping[assignment] = object_mapping[assignment] self.object_dependency_mapping[assignment] = object_mapping[assignment]
# Add them to global nodes
self.global_nodes.update(self.assignments)
def _merge_classes(self, classes: dict[str, cst.CSTNode]): def _merge_classes(self, classes: dict[str, cst.CSTNode]):
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
@@ -813,10 +816,7 @@ class ModelFileMapper(ModuleMapper):
self._merge_classes(classes) self._merge_classes(classes)
self.modular_file_start_lines = start_lines self.modular_file_start_lines = start_lines
# Correctly re-set the global nodes at this point # Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the know entities to avoid Python's built-ins
self._restrict_dependencies_to_known_entities() self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments # Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
@@ -1024,14 +1024,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
import_ref_count[name] = ref_count import_ref_count[name] = ref_count
imports_to_keep = [] imports_to_keep = []
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
for node in all_imports: for node in all_imports:
if m.matches(node, m.If()): # handle safe imports if m.matches(node, m.If()): # handle safe imports
new_statements = [] new_statements = []
for stmt_node in node.body.body: for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements) append_new_import_node(stmt_node, unused_imports, new_statements)
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
if len(new_statements) > 0: if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node) imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
else: else:
append_new_import_node(node, unused_imports, imports_to_keep) append_new_import_node(node, unused_imports, imports_to_keep)