Feat: add warnings for unused keys and rules in tensor parallel (#37893)
Feat: tensor parallel plan verification
This commit is contained in:
@@ -64,6 +64,7 @@ from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
shard_and_distribute_module,
|
||||
verify_tp_plan,
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
@@ -4974,6 +4975,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
if logger.level >= logging.WARNING:
|
||||
verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and not is_hqq_or_quark:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
|
||||
Reference in New Issue
Block a user