[WIP] add deepseek-v3 (#35926)
* init commit * style * take comments into account * add deepseekv3 modeling * remove redundant code * apply make style * apply fix-copies * make format * add init files * rename deepseekv3 into deepseek_v3 based on its model_type * rename deepseekv3 into deepseek_v3 based on its model_type * deepseek-v3 not deepseek_v3 * set model_type as deepseek_v3 * use default docs * apply make * fill type and docstring * add rope_config_validation * use custom DeepseekV3MLP * hold code only for checkpoints congifuration; remove redundant * revise rope yarn for DeepSeek variation * rename DeepSeek-V3 * some refactoring * revise load_hook to work properly; make moe func trainable; use llama instead of mixtral * fix attention forward * use -1 for not-changing dim when to use exapnd * refactor DeepseekV3TopkRouter * use reshape_for_rope instead of load_hook; revise attention forward for TP; rename q_head_dim with qk_head_dim * register pre_hook and hook both * make style * use n_shared_experts * Update src/transformers/models/deepseek_v3/configuration_deepseek_v3.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add test file * update modeling_file according to modular file * make style * add mapping for DeepseekV3ForSequenceClassification * remove aux_loss_alpha * add deepseek_v3 for perf * add deepseek_v3 * rename test as deepseekv3 * use tiny-deepseek-v3 * remove DeepseekV3ForSequenceClassification * cache before padding * remote output_router_logits * Revert "remote output_router_logits" This reverts commit f264f800d04950390db8413b9efb24cef8186330. 
* remove output_router_logits * make e_score_correction_bias as buffer * skip tests not compatible * make style * make e_score_correction_bias as buffer * use rope_interleave instead of load_hook * skip tests not compatible with MLA * add doc for rope_interleave * fix typo * remove torch.no_grad for selecting topk * fix post merge issue * mrege with main and simplify * nits * final * small fixes * fix * support TP better * stash * changes currently requires * remove synch * more fixes for TP * temp fix for TP : some attention layers's FP8 scales are too small + shared is local colwise and anything is local if FP8 because weights are used * updates to have generation work! * push most of the changes * reorder functions + call for contributions! * update readme * nits * update * ruff was updated on main * merge with main and fix copies * revert unrelated changes * route all tokens to all experts when testing to avoid no gradient iddues * finish fixing all tests * fixup * nit * clean config * last readme changes * nit * do cnit * typo * last nit * one more one more --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-165-131.ec2.internal>
This commit is contained in:
@@ -779,8 +779,7 @@ def _load_state_dict_into_meta_model(
|
||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_quantized
|
||||
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors")
|
||||
file_pointer = None
|
||||
if is_meta_state_dict:
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
@@ -795,7 +794,7 @@ def _load_state_dict_into_meta_model(
|
||||
serialized_param_name = reverse_renaming_mapping[param_name]
|
||||
param = file_pointer.get_slice(serialized_param_name)
|
||||
else:
|
||||
param = empty_param # It is actually not empty!
|
||||
param = empty_param.to(tensor_device) # It is actually not empty!
|
||||
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
@@ -813,7 +812,7 @@ def _load_state_dict_into_meta_model(
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
tensor_device, # the rank
|
||||
int(os.environ["RANK"]), # the rank
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
@@ -4102,11 +4101,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
logger.warning("Tensor Parallel requires torch.distributed to be initialized first.")
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
@@ -4115,12 +4115,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
# Get device with index assuming equal number of devices per host
|
||||
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
|
||||
tp_device = torch.device(device_type, torch.cuda.current_device())
|
||||
if tp_device.index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
|
||||
# Assuming sharding the model onto the world
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
|
||||
@@ -4871,9 +4872,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and hf_quantizer is None:
|
||||
if device_map is not None: # and hf_quantizer is None:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
error_msgs = []
|
||||
mismatched_keys = []
|
||||
@@ -5834,7 +5835,7 @@ def expand_device_map(device_map, param_names):
|
||||
return new_device_map
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
|
||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
||||
the model, which is actually the loading speed bottleneck.
|
||||
@@ -5865,7 +5866,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
||||
if _torch_distributed_available and torch.distributed.is_initialized()
|
||||
else None
|
||||
)
|
||||
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
@@ -5886,7 +5886,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
||||
# Allow up to 95% of max device memory
|
||||
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
|
||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def get_disk_only_shard_files(device_map, weight_map):
|
||||
|
||||
Reference in New Issue
Block a user