From bc3d5781e72dc2b27ae39f290bcc7336f9839db7 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Thu, 13 Mar 2025 12:16:13 +0100 Subject: [PATCH] Fix slicing for 0-dim param (#36580) * fix * switch to ellipsis instead * Add co-author Co-authored-by: fxmarty-amd * Add co-author second try Co-authored-by: fxmarty-amd --- src/transformers/integrations/tensor_parallel.py | 2 +- src/transformers/modeling_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 9e8a0dec76..18163f230e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -531,7 +531,7 @@ def shard_and_distribute_module( param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh ) else: - param = param[:] + param = param[...] if is_contiguous: param = param.contiguous() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 207ddafa97..77f842aa5f 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model( device_mesh, ) else: - param = param[:] + param = param[...] if casting_dtype is not None: param = param.to(casting_dtype) if to_contiguous: