Fix slicing for 0-dim param (#36580)
* fix
* switch to ellipsis instead
* Add co-author

Co-authored-by: fxmarty-amd <fxmarty-amd@users.noreply.github.com>

* Add co-author second try

Co-authored-by: fxmarty-amd <felmarty@amd.com>
This commit is contained in:
@@ -531,7 +531,7 @@ def shard_and_distribute_module(
|
|||||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
param = param[:]
|
param = param[...]
|
||||||
if is_contiguous:
|
if is_contiguous:
|
||||||
param = param.contiguous()
|
param = param.contiguous()
|
||||||
|
|
||||||
|
|||||||
@@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
device_mesh,
|
device_mesh,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
param = param[:]
|
param = param[...]
|
||||||
if casting_dtype is not None:
|
if casting_dtype is not None:
|
||||||
param = param.to(casting_dtype)
|
param = param.to(casting_dtype)
|
||||||
if to_contiguous:
|
if to_contiguous:
|
||||||
|
|||||||
Reference in New Issue
Block a user