Add strategy to store results in evaluation loop (#30267)

* Add evaluation loop container for intermediate results

* Add tests for EvalLoopContainer

* Formatting

* Fix padding_index in test and typo

* Move EvalLoopContainer to pr_utils to avoid additional imports

* Fix `eval_do_concat_batches` arg description

* Fix EvalLoopContainer import
This commit is contained in:
Pavel Iakubovskii
2024-04-17 12:42:27 +01:00
committed by GitHub
parent 8d6b509611
commit c15aad0939
4 changed files with 175 additions and 59 deletions

View File

@@ -35,6 +35,7 @@ if is_torch_available():
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
@@ -497,3 +498,92 @@ class TrainerUtilsTest(unittest.TestCase):
remove_columns_collator(data_batch)
self.assertEqual(logger.called, 1)
self.assertIn("col3", logger.last_msg)
def test_eval_loop_container(self):
    """Exercise EvalLoopContainer with nested concatenation, plain list storage, no batches and one batch."""
    # Each batch is a nested structure: a tensor, a dict holding a scalar loss, and a tuple of
    # two tensors. The second tensor of the tuple has a different trailing size in each batch (2 vs 6).
    first_batch = [
        torch.ones([8, 5]),
        {"loss": torch.tensor(1.0)},
        (torch.ones([8, 2, 3]), torch.ones([8, 2])),
    ]
    second_batch = [
        torch.ones([4, 5]),
        {"loss": torch.tensor(2.0)},
        (torch.ones([4, 2, 3]), torch.ones([4, 6])),
    ]

    # Test two nested batches concatenation
    merging_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    for batch in (first_batch, second_batch):
        merging_container.add(batch)
    merging_container.to_cpu_and_numpy()
    merged = merging_container.get_arrays()

    self.assertIsInstance(merged, list)
    self.assertEqual(len(merged), 3)
    features, metrics, pair = merged
    self.assertIsInstance(features, np.ndarray)
    self.assertEqual(features.shape, (12, 5))
    self.assertIsInstance(metrics, dict)
    losses = metrics["loss"]
    self.assertIsInstance(losses, np.ndarray)
    self.assertEqual(losses.shape, (2,))
    self.assertTrue(np.allclose(losses, np.array([1.0, 2.0])))
    self.assertIsInstance(pair, tuple)
    self.assertEqual(len(pair), 2)
    self.assertEqual(pair[0].shape, (12, 2, 3))
    self.assertEqual(pair[1].shape, (12, 6))
    # check that first batch padded with padding index -100 after concatenation
    self.assertEqual(pair[1][0][2], -100)

    # Test two batches with no concatenation
    listing_container = EvalLoopContainer(do_nested_concat=False)
    for batch in (first_batch, second_batch):
        listing_container.add(batch)
    listing_container.to_cpu_and_numpy()
    stored = listing_container.get_arrays()

    self.assertEqual(len(stored), 2)
    self.assertIsInstance(stored, list)
    stored_first, stored_second = stored
    # Expected shapes per stored batch: (features, first tuple item, second tuple item).
    expectations = (
        (stored_first, (8, 5), (8, 2, 3), (8, 2)),
        (stored_second, (4, 5), (4, 2, 3), (4, 6)),
    )
    for np_batch, features_shape, pair_first_shape, pair_second_shape in expectations:
        self.assertIsInstance(np_batch, list)
        self.assertEqual(len(np_batch), 3)
        self.assertIsInstance(np_batch[0], np.ndarray)
        self.assertIsInstance(np_batch[1], dict)
        self.assertIsInstance(np_batch[2], tuple)
        self.assertEqual(np_batch[0].shape, features_shape)
        # Without concatenation the loss is expected to stay a 0-d array.
        self.assertEqual(np_batch[1]["loss"].shape, ())
        self.assertEqual(np_batch[2][0].shape, pair_first_shape)
        self.assertEqual(np_batch[2][1].shape, pair_second_shape)

    # Test no batches
    for init_kwargs in ({"do_nested_concat": True, "padding_index": -100}, {"do_nested_concat": False}):
        self.assertIsNone(EvalLoopContainer(**init_kwargs).get_arrays())

    # Test one batch
    # get_arrays() is called directly here, with no explicit to_cpu_and_numpy() step beforehand;
    # the results are still expected to be numpy arrays.
    single_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
    single_container.add(first_batch)
    single = single_container.get_arrays()

    self.assertIsInstance(single, list)
    self.assertEqual(len(single), 3)
    features, metrics, pair = single
    self.assertIsInstance(features, np.ndarray)
    self.assertEqual(features.shape, (8, 5))
    self.assertIsInstance(metrics, dict)
    loss = metrics["loss"]
    self.assertIsInstance(loss, np.ndarray)
    self.assertEqual(loss.shape, ())
    self.assertTrue(np.allclose(loss, np.array([1.0])))
    self.assertIsInstance(pair, tuple)
    self.assertEqual(len(pair), 2)
    self.assertEqual(pair[0].shape, (8, 2, 3))
    self.assertEqual(pair[1].shape, (8, 2))