Return scalar losses instead of per-sample means (#18013)

* Return scalar losses instead of per-sample means

* Make loss shape (1,) instead of scalar

* Allow scalar losses in test_loss_computation

* Allow scalar losses in test_loss_computation

* Allow scalar losses in test_loss_computation

* Remove XLA loss function for RAG
This commit is contained in:
Matt
2022-07-04 17:26:19 +01:00
committed by GitHub
parent 6cb19540c9
commit 96d833b211
7 changed files with 39 additions and 63 deletions

View File

@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
input_ids = prepared_for_class.pop(input_name)
loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape.as_list(), expected_loss_size)
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly computes the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertEqual(loss.shape.as_list(), expected_loss_size)
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly computes the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
# Send to model
loss = model(tuple_input[:-1])[0]
self.assertEqual(loss.shape.as_list(), expected_loss_size)
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
@require_tf