Fix distributed gather for tuples of tensors of varying sizes (#11071)
This commit is contained in:
@@ -82,6 +82,49 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(result[1][0], predictions))
|
||||
self.assertTrue(np.array_equal(result[1][1], predictions))
|
||||
|
||||
def test_distributed_tensor_gatherer_different_shapes(self):
|
||||
# Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
|
||||
world_size = 4
|
||||
num_samples = 21
|
||||
input_indices = [
|
||||
[0, 1, 6, 7, 12, 13, 18, 19],
|
||||
[2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
|
||||
[5, 11, 17, 2],
|
||||
]
|
||||
sequence_lengths = [8, 10, 13]
|
||||
|
||||
predictions = np.random.normal(size=(num_samples, 13))
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays(predictions[indices, :seq_length])
|
||||
result = gatherer.finalize()
|
||||
|
||||
# Remove the extra samples added at the end for a round multiple of num processes.
|
||||
actual_indices = [input_indices[0], input_indices[1][:-2], input_indices[2][:-1]]
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[indices, :seq_length], predictions[indices, :seq_length]))
|
||||
|
||||
# With nested tensors
|
||||
predictions = np.random.normal(size=(num_samples, 13))
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays([predictions[indices, :seq_length], predictions[indices]])
|
||||
result = gatherer.finalize()
|
||||
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[0][indices, :seq_length], predictions[indices, :seq_length]))
|
||||
self.assertTrue(np.array_equal(result[1], predictions))
|
||||
|
||||
# Check if works if varying seq_length is second
|
||||
gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
|
||||
for indices, seq_length in zip(input_indices, sequence_lengths):
|
||||
gatherer.add_arrays([predictions[indices], predictions[indices, :seq_length]])
|
||||
result = gatherer.finalize()
|
||||
|
||||
self.assertTrue(np.array_equal(result[0], predictions))
|
||||
for indices, seq_length in zip(actual_indices, sequence_lengths):
|
||||
self.assertTrue(np.array_equal(result[1][indices, :seq_length], predictions[indices, :seq_length]))
|
||||
|
||||
def test_label_smoothing(self):
|
||||
epsilon = 0.1
|
||||
num_labels = 12
|
||||
|
||||
Reference in New Issue
Block a user