From 47412c7d434f6ddfc02a9b7ecd6182b86ae0a164 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 11 May 2022 19:19:08 +0200 Subject: [PATCH] Ensure tensors are at least 1d for pad and concat (#17179) * Ensure tensors are at least 1d for pad and concat * Compatibility * Fix * Fix * Add test * Retrigger CI * Consistency with master * Retrigger CI --- src/transformers/trainer_pt_utils.py | 25 ++++++++++++++++++++----- tests/trainer/test_trainer_utils.py | 14 ++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index ac83826e40..fa1596f4c6 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -55,8 +55,22 @@ except ImportError: logger = logging.get_logger(__name__) +def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]): + if isinstance(tensor_or_array, torch.Tensor): + if hasattr(torch, "atleast_1d"): + tensor_or_array = torch.atleast_1d(tensor_or_array) + elif tensor_or_array.ndim < 1: + tensor_or_array = tensor_or_array[None] + else: + tensor_or_array = np.atleast_1d(tensor_or_array) + return tensor_or_array + + def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100): """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary.""" + tensor1 = atleast_1d(tensor1) + tensor2 = atleast_1d(tensor2) + if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]: return torch.cat((tensor1, tensor2), dim=0) @@ -72,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100): def numpy_pad_and_concatenate(array1, array2, padding_index=-100): """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary.""" + array1 = atleast_1d(array1) + array2 = atleast_1d(array2) + if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]: return np.concatenate((array1, array2), axis=0) @@ -149,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name): if isinstance(tensors, (list, tuple)): return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) - if tensors.ndim == 0: - tensors = tensors[None] + tensors = atleast_1d(tensors) return xm.mesh_reduce(name, tensors, torch.cat) else: raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") @@ -160,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> try: if isinstance(tensor, (tuple, list)): return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) - if len(tensor.shape) <= 0: - tensor = tensor[None] + tensor = atleast_1d(tensor) output_tensors = [tensor.clone() for _ in range(dist.get_world_size())] dist.all_gather(output_tensors, tensor) concat = torch.cat(output_tensors, dim=0) @@ -1031,7 +1046,7 @@ if is_sagemaker_mp_enabled(): f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." ) all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP) - all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors] + all_tensors = [atleast_1d(t) for t in all_tensors] return torch.cat([t.cpu() for t in all_tensors], dim=0) def smp_nested_concat(tensor): diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index 168beb95b9..869d19b0a1 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -41,6 +41,8 @@ if is_torch_available(): SequentialDistributedSampler, ShardSampler, get_parameter_names, + numpy_pad_and_concatenate, + torch_pad_and_concatenate, ) class TstLayer(nn.Module): @@ -459,6 +461,18 @@ class TrainerUtilsTest(unittest.TestCase): mock_training_loop_function() self.assertEqual("CUDA out of memory", cm.args[0]) + def test_pad_and_concatenate_with_1d(self): + """Tests whether pad_and_concatenate works with scalars.""" + array1 = 1.0 + array2 = 2.0 + result = numpy_pad_and_concatenate(array1, array2) + self.assertTrue(np.array_equal(np.array([1.0, 2.0]), result)) + + tensor1 = torch.tensor(1.0) + tensor2 = torch.tensor(2.0) + result = torch_pad_and_concatenate(tensor1, tensor2) + self.assertTrue(torch.equal(result, torch.Tensor([1.0, 2.0]))) + def test_remove_columns_collator(self): class MockLogger: def __init__(self) -> None: