Add strategy to store results in evaluation loop (#30267)
* Add evaluation loop container for interm. results * Add tests for EvalLoopContainer * Formatting * Fix padding_index in test and typo * Move EvalLoopContainer to pr_utils to avoid additional imports * Fix `eval_do_concat_batches` arg description * Fix EvalLoopContainer import
This commit is contained in:
committed by
GitHub
parent
8d6b509611
commit
c15aad0939
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
EvalLoopContainer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
@@ -497,3 +498,92 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
remove_columns_collator(data_batch)
|
||||
self.assertEqual(logger.called, 1)
|
||||
self.assertIn("col3", logger.last_msg)
|
||||
|
||||
def test_eval_loop_container(self):
|
||||
batch_1 = [
|
||||
torch.ones([8, 5]),
|
||||
{"loss": torch.tensor(1.0)},
|
||||
(torch.ones([8, 2, 3]), torch.ones([8, 2])),
|
||||
]
|
||||
batch_2 = [
|
||||
torch.ones([4, 5]),
|
||||
{"loss": torch.tensor(2.0)},
|
||||
(torch.ones([4, 2, 3]), torch.ones([4, 6])),
|
||||
]
|
||||
|
||||
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
|
||||
concat_container.add(batch_1)
|
||||
concat_container.add(batch_2)
|
||||
concat_container.to_cpu_and_numpy()
|
||||
arrays = concat_container.get_arrays()
|
||||
|
||||
# Test two nested batches concatenation
|
||||
self.assertIsInstance(arrays, list)
|
||||
self.assertEqual(len(arrays), 3)
|
||||
self.assertIsInstance(arrays[0], np.ndarray)
|
||||
self.assertEqual(arrays[0].shape, (12, 5))
|
||||
self.assertIsInstance(arrays[1], dict)
|
||||
self.assertIsInstance(arrays[1]["loss"], np.ndarray)
|
||||
self.assertEqual(arrays[1]["loss"].shape, (2,))
|
||||
self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0])))
|
||||
self.assertIsInstance(arrays[2], tuple)
|
||||
self.assertEqual(len(arrays[2]), 2)
|
||||
self.assertEqual(arrays[2][0].shape, (12, 2, 3))
|
||||
self.assertEqual(arrays[2][1].shape, (12, 6))
|
||||
# check that first batch padded with padding index -100 after concatenation
|
||||
self.assertEqual(arrays[2][1][0][2], -100)
|
||||
|
||||
# Test two batches with no concatenation
|
||||
list_container = EvalLoopContainer(do_nested_concat=False)
|
||||
list_container.add(batch_1)
|
||||
list_container.add(batch_2)
|
||||
list_container.to_cpu_and_numpy()
|
||||
arrays = list_container.get_arrays()
|
||||
|
||||
self.assertEqual(len(arrays), 2)
|
||||
self.assertIsInstance(arrays, list)
|
||||
np_batch_1, np_batch_2 = arrays
|
||||
|
||||
self.assertIsInstance(np_batch_1, list)
|
||||
self.assertEqual(len(np_batch_1), 3)
|
||||
self.assertIsInstance(np_batch_1[0], np.ndarray)
|
||||
self.assertIsInstance(np_batch_1[1], dict)
|
||||
self.assertIsInstance(np_batch_1[2], tuple)
|
||||
self.assertEqual(np_batch_1[0].shape, (8, 5))
|
||||
self.assertEqual(np_batch_1[1]["loss"].shape, ())
|
||||
self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3))
|
||||
self.assertEqual(np_batch_1[2][1].shape, (8, 2))
|
||||
|
||||
self.assertIsInstance(np_batch_2, list)
|
||||
self.assertEqual(len(np_batch_2), 3)
|
||||
self.assertIsInstance(np_batch_2[0], np.ndarray)
|
||||
self.assertIsInstance(np_batch_2[1], dict)
|
||||
self.assertIsInstance(np_batch_2[2], tuple)
|
||||
self.assertEqual(np_batch_2[0].shape, (4, 5))
|
||||
self.assertEqual(np_batch_2[1]["loss"].shape, ())
|
||||
self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3))
|
||||
self.assertEqual(np_batch_2[2][1].shape, (4, 6))
|
||||
|
||||
# Test no batches
|
||||
none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
|
||||
self.assertIsNone(none_arr)
|
||||
|
||||
none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
|
||||
self.assertIsNone(none_arr)
|
||||
|
||||
# Test one batch
|
||||
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
|
||||
concat_container.add(batch_1)
|
||||
arrays = concat_container.get_arrays()
|
||||
self.assertIsInstance(arrays, list)
|
||||
self.assertEqual(len(arrays), 3)
|
||||
self.assertIsInstance(arrays[0], np.ndarray)
|
||||
self.assertEqual(arrays[0].shape, (8, 5))
|
||||
self.assertIsInstance(arrays[1], dict)
|
||||
self.assertIsInstance(arrays[1]["loss"], np.ndarray)
|
||||
self.assertEqual(arrays[1]["loss"].shape, ())
|
||||
self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0])))
|
||||
self.assertIsInstance(arrays[2], tuple)
|
||||
self.assertEqual(len(arrays[2]), 2)
|
||||
self.assertEqual(arrays[2][0].shape, (8, 2, 3))
|
||||
self.assertEqual(arrays[2][1].shape, (8, 2))
|
||||
|
||||
Reference in New Issue
Block a user