From e3a5889ef09ed60444d5eff4314f1e87909e2739 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 19 Nov 2024 16:08:57 +0100 Subject: [PATCH] Modular fix (#34802) * Modular fix * style * remove logger warning * Update modular_model_converter.py --- .../configuration_my_new_model.py | 10 +++ .../configuration_my_new_model2.py | 10 +++ .../modular-transformers/modeling_dummy.py | 71 +++++-------------- .../modeling_my_new_model2.py | 5 +- .../modular-transformers/modeling_super.py | 69 +++++------------- utils/modular_model_converter.py | 13 ++-- 6 files changed, 65 insertions(+), 113 deletions(-) diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index aa0aac55ba..7042c586cb 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -130,6 +130,16 @@ class MyNewModelConfig(PretrainedConfig): model_type = "my_new_model" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `MyNewModelModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index f05ace94b6..eddd7fe479 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -33,6 +33,16 @@ class MyNewModel2Config(PretrainedConfig): model_type = "my_new_model2" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `MyNewModel2Model` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } def __init__( self, diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ed7e3c64d7..0b373d4e6e 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -8,7 +8,6 @@ import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -150,25 +149,7 @@ class DummyMLP(nn.Module): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -264,31 +245,14 @@ class DummyAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -330,12 +294,7 @@ class DummyAttention(nn.Module): attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -508,9 +467,10 @@ class DummySdpaAttention(DummyAttention): key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -794,7 +754,10 @@ class DummyModel(DummyPreTrainedModel): ) self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DummyRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -874,7 +837,7 @@ class DummyModel(DummyPreTrainedModel): all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 16f9e525a0..189e090094 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -667,7 +667,10 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -752,7 +755,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): all_self_attns = () if output_attentions else None next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 7df04bcc2a..7ad606280d 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -8,7 +8,6 @@ import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -150,25 +149,7 @@ class SuperMLP(nn.Module): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -264,31 +245,14 @@ class SuperAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -330,12 +294,7 @@ class SuperAttention(nn.Module): attn_output = attn_output.reshape(bsz, q_len, -1) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None @@ -508,9 +467,10 @@ class SuperSdpaAttention(SuperAttention): key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -794,7 +754,10 @@ class SuperModel(SuperPreTrainedModel): ) self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SuperRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index e5f6e34ece..ccf15363de 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -266,7 +266,6 @@ class SuperTransformer(cst.CSTTransformer): if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): target = self.python_module.code_for_node(stmt.body[0].targets[0].target) if target in self.deleted_targets: - logger.warning(f"Deleted the assign for {target}") continue if target in self.all_assign_target: stmt = self.all_assign_target[target] @@ -773,6 +772,8 @@ class ModelFileMapper(ModuleMapper): self.object_dependency_mapping.update( {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} ) + # Add them to global nodes + self.global_nodes.update(self.functions) def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): """Update the global nodes with the assignment from the modular file. @@ -786,6 +787,8 @@ class ModelFileMapper(ModuleMapper): self.assignments[assignment] = node if assignment in object_mapping: self.object_dependency_mapping[assignment] = object_mapping[assignment] + # Add them to global nodes + self.global_nodes.update(self.assignments) def _merge_classes(self, classes: dict[str, cst.CSTNode]): """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and @@ -813,10 +816,7 @@ class ModelFileMapper(ModuleMapper): self._merge_classes(classes) self.modular_file_start_lines = start_lines - # Correctly re-set the global nodes at this point - self.global_nodes.update(self.functions) - self.global_nodes.update(self.assignments) - # Restrict the dependency mappings to the know entities to avoid Python's built-ins + # Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports self._restrict_dependencies_to_known_entities() # Create the global mapping of recursive dependencies for functions and assignments self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() @@ -1024,14 +1024,17 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> import_ref_count[name] = ref_count imports_to_keep = [] + existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly for node in all_imports: if m.matches(node, m.If()): # handle safe imports new_statements = [] for stmt_node in node.body.body: append_new_import_node(stmt_node, unused_imports, new_statements) + new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements] if len(new_statements) > 0: new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) imports_to_keep.append(new_node) + existing_protected_statements.update({str(stmt) for stmt in new_statements}) else: append_new_import_node(node, unused_imports, imports_to_keep)