Feat: add warnings for unused keys and rules in tensor parallel (#37893)

Feat: tensor parallel plan verification
This commit is contained in:
Matej Sirovatka
2025-05-16 14:52:47 +02:00
committed by GitHub
parent 120935234f
commit 7b5e327c6e
2 changed files with 35 additions and 0 deletions

View File

@@ -64,6 +64,7 @@ from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
shard_and_distribute_module,
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
@@ -4974,6 +4975,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
if logger.level >= logging.WARNING:
verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
# Warmup cuda to load the weights much faster on devices
if device_map is not None and not is_hqq_or_quark:
expanded_device_map = expand_device_map(device_map, expected_keys)