Ensure tensors are at least 1d for pad and concat (#17179)
* Ensure tensors are at least 1d for pad and concat * Compatibility * Fix * Fix * Add test * Retrigger CI * Consistency with master * Retrigger CI
This commit is contained in:
@@ -55,8 +55,22 @@ except ImportError:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
    """
    Return `tensor_or_array` with at least one dimension.

    Torch tensors are passed to `torch.atleast_1d` when the installed torch exposes it; otherwise a tensor with
    fewer than one dimension is indexed with `None` to add a leading axis. Anything that is not a torch tensor is
    handed to `np.atleast_1d`.
    """
    # Non-tensor inputs (NumPy arrays, Python scalars, ...) are delegated to NumPy.
    if not isinstance(tensor_or_array, torch.Tensor):
        return np.atleast_1d(tensor_or_array)

    # Use the native torch helper whenever this torch build provides it.
    if hasattr(torch, "atleast_1d"):
        return torch.atleast_1d(tensor_or_array)

    # Fallback for torch builds without `atleast_1d`: add a leading axis to 0-d tensors only.
    if tensor_or_array.ndim < 1:
        return tensor_or_array[None]
    return tensor_or_array
|
||||||
|
|
||||||
|
|
||||||
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
||||||
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
|
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
|
||||||
|
tensor1 = atleast_1d(tensor1)
|
||||||
|
tensor2 = atleast_1d(tensor2)
|
||||||
|
|
||||||
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
|
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
|
||||||
return torch.cat((tensor1, tensor2), dim=0)
|
return torch.cat((tensor1, tensor2), dim=0)
|
||||||
|
|
||||||
@@ -72,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
|||||||
|
|
||||||
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
|
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
|
||||||
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
|
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
|
||||||
|
array1 = atleast_1d(array1)
|
||||||
|
array2 = atleast_1d(array2)
|
||||||
|
|
||||||
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
|
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
|
||||||
return np.concatenate((array1, array2), axis=0)
|
return np.concatenate((array1, array2), axis=0)
|
||||||
|
|
||||||
@@ -149,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name):
|
|||||||
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
if isinstance(tensors, (list, tuple)):
|
||||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||||
if tensors.ndim == 0:
|
tensors = atleast_1d(tensors)
|
||||||
tensors = tensors[None]
|
|
||||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||||
else:
|
else:
|
||||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||||
@@ -160,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
|
|||||||
try:
|
try:
|
||||||
if isinstance(tensor, (tuple, list)):
|
if isinstance(tensor, (tuple, list)):
|
||||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||||
if len(tensor.shape) <= 0:
|
tensor = atleast_1d(tensor)
|
||||||
tensor = tensor[None]
|
|
||||||
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
||||||
dist.all_gather(output_tensors, tensor)
|
dist.all_gather(output_tensors, tensor)
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
@@ -1031,7 +1046,7 @@ if is_sagemaker_mp_enabled():
|
|||||||
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
|
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
|
||||||
)
|
)
|
||||||
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
|
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
|
||||||
all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
|
all_tensors = [atleast_1d(t) for t in all_tensors]
|
||||||
return torch.cat([t.cpu() for t in all_tensors], dim=0)
|
return torch.cat([t.cpu() for t in all_tensors], dim=0)
|
||||||
|
|
||||||
def smp_nested_concat(tensor):
|
def smp_nested_concat(tensor):
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ if is_torch_available():
|
|||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
ShardSampler,
|
ShardSampler,
|
||||||
get_parameter_names,
|
get_parameter_names,
|
||||||
|
numpy_pad_and_concatenate,
|
||||||
|
torch_pad_and_concatenate,
|
||||||
)
|
)
|
||||||
|
|
||||||
class TstLayer(nn.Module):
|
class TstLayer(nn.Module):
|
||||||
@@ -459,6 +461,18 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
mock_training_loop_function()
|
mock_training_loop_function()
|
||||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||||
|
|
||||||
|
def test_pad_and_concatenate_with_1d(self):
    """Check that both pad-and-concatenate helpers accept scalar inputs."""
    # NumPy helper: two plain Python floats should come back as one 1-d array.
    first, second = 1.0, 2.0
    combined = numpy_pad_and_concatenate(first, second)
    expected_array = np.array([1.0, 2.0])
    self.assertTrue(np.array_equal(expected_array, combined))

    # Torch helper: two 0-d tensors should come back as one 1-d tensor.
    first, second = torch.tensor(1.0), torch.tensor(2.0)
    combined = torch_pad_and_concatenate(first, second)
    expected_tensor = torch.Tensor([1.0, 2.0])
    self.assertTrue(torch.equal(combined, expected_tensor))
|
||||||
|
|
||||||
def test_remove_columns_collator(self):
|
def test_remove_columns_collator(self):
|
||||||
class MockLogger:
|
class MockLogger:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user