XLA train step fixes (#17973)
* Copy inputs to train and test step before modifying them, as this breaks things
* Add XLA tests, fix our loss functions to be XLA-compatible
* make fixup
* Update loss computation test to expect vector of per-sample losses
* Patch loss for TFLED
* Patch loss for TFAlbert
* Add a tf_legacy_loss config flag that enables old loss functions
* Stop using config.get() because it's not a dict
* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it
* make fixup
* Add XLA-compatible RAG loss
* Fix dtype of loss mask for TFAlbert
* Fix test for XLNet too because it overrides the default one
* make fixup
* Fix config test
* No more depending on GPU NaN behaviour
* Add test, avoid potential zero division
* Fix test item assignment
* Fix loss computation masking test
* make fixup
* Fix dtype bugs
This commit is contained in:
@@ -403,7 +403,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
added_label = prepared_for_class[
|
||||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
||||
]
|
||||
loss_size = tf.size(added_label)
|
||||
expected_loss_size = added_label.shape.as_list()[:1]
|
||||
|
||||
# `TFXLNetLMHeadModel` doesn't cut logits/labels
|
||||
# if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||
@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
input_ids = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(input_ids, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly computes the loss with a dict
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
loss = model(prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly computes the loss with a tuple
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# Send to model
|
||||
loss = model(tuple_input[:-1])[0]
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
|
||||
@require_tf
|
||||
|
||||
Reference in New Issue
Block a user