Fix distributed gather for tuples of tensors of varying sizes (#11071)
This commit is contained in:
@@ -276,11 +276,8 @@ def nested_new_like(arrays, num_samples, padding_index=-100):
|
||||
return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
|
||||
|
||||
|
||||
def nested_expand_like(arrays, new_seq_length, padding_index=-100):
|
||||
def expand_like(arrays, new_seq_length, padding_index=-100):
|
||||
""" Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays)
|
||||
|
||||
result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
|
||||
result[:, : arrays.shape[1]] = arrays
|
||||
return result
|
||||
@@ -293,13 +290,6 @@ def nested_truncate(tensors, limit):
|
||||
return tensors[:limit]
|
||||
|
||||
|
||||
def _get_first_shape(arrays):
|
||||
"""Return the shape of the first array found in the nested struct `arrays`."""
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
return _get_first_shape(arrays[0])
|
||||
return arrays.shape
|
||||
|
||||
|
||||
class DistributedTensorGatherer:
|
||||
"""
|
||||
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
||||
@@ -367,21 +357,15 @@ class DistributedTensorGatherer:
|
||||
if self._storage is None:
|
||||
self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
|
||||
self._offsets = list(range(0, self.total_samples, self.process_length))
|
||||
else:
|
||||
storage_shape = _get_first_shape(self._storage)
|
||||
arrays_shape = _get_first_shape(arrays)
|
||||
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
|
||||
# If we get new arrays that are too big too fit, we expand the shape fo the storage
|
||||
self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index)
|
||||
slice_len = self._nested_set_tensors(self._storage, arrays)
|
||||
|
||||
slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
|
||||
for i in range(self.world_size):
|
||||
self._offsets[i] += slice_len
|
||||
|
||||
def _nested_set_tensors(self, storage, arrays):
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
for x, y in zip(storage, arrays):
|
||||
slice_len = self._nested_set_tensors(x, y)
|
||||
return slice_len
|
||||
result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
|
||||
return result[0][0], type(arrays)(r[1] for r in result)
|
||||
assert (
|
||||
arrays.shape[0] % self.world_size == 0
|
||||
), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
|
||||
@@ -391,10 +375,13 @@ class DistributedTensorGatherer:
|
||||
if len(arrays.shape) == 1:
|
||||
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
|
||||
else:
|
||||
# Expand the array on the fly if needed.
|
||||
if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
|
||||
storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
|
||||
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
|
||||
i * slice_len : (i + 1) * slice_len
|
||||
]
|
||||
return slice_len
|
||||
return slice_len, storage
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
|
||||
@@ -82,6 +82,49 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(result[1][0], predictions))
|
||||
self.assertTrue(np.array_equal(result[1][1], predictions))
|
||||
|
||||
def test_distributed_tensor_gatherer_different_shapes(self):
|
||||
# Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
|
||||
world_size = 4
|
||||
num_samples = 21
|
||||
input_indices = [
|
||||
[0, 1, 6, 7, 12, 13, 18, 19],
|
||||
[2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
|
||||
[5, 11, 17, 2],
|
||||
]
|
||||
sequence_lengths = [8, 10, 13]
|
||||
|
||||
predictions = np.random.normal(size=(num_samples, 13))
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays(predictions[indices, :seq_length])
|
||||
result = gatherer.finalize()
|
||||
|
||||
# Remove the extra samples added at the end for a round multiple of num processes.
|
||||
actual_indices = [input_indices[0], input_indices[1][:-2], input_indices[2][:-1]]
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[indices, :seq_length], predictions[indices, :seq_length]))
|
||||
|
||||
# With nested tensors
|
||||
predictions = np.random.normal(size=(num_samples, 13))
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays([predictions[indices, :seq_length], predictions[indices]])
|
||||
result = gatherer.finalize()
|
||||
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[0][indices, :seq_length], predictions[indices, :seq_length]))
|
||||
self.assertTrue(np.array_equal(result[1], predictions))
|
||||
|
||||
# Check if works if varying seq_length is second
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays([predictions[indices], predictions[indices, :seq_length]])
|
||||
result = gatherer.finalize()
|
||||
|
||||
self.assertTrue(np.array_equal(result[0], predictions))
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[1][indices, :seq_length], predictions[indices, :seq_length]))
|
||||
|
||||
def test_label_smoothing(self):
|
||||
epsilon = 0.1
|
||||
num_labels = 12
|
||||
|
||||
Reference in New Issue
Block a user