XLA train step fixes (#17973)

* Copy inputs to train and test step before modifying them, as this breaks things

* Add XLA tests, fix our loss functions to be XLA-compatible

* make fixup

* Update loss computation test to expect vector of per-sample losses

* Patch loss for TFLED

* Patch loss for TFAlbert

* Add a tf_legacy_loss config flag that enables old loss functions

* Stop using config.get() because it's not a dict

* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it

* make fixup

* Add XLA-compatible RAG loss

* Fix dtype of loss mask for TFAlbert

* Fix test for XLNet too because it overrides the default one

* make fixup

* Fix config test

* No more depending on GPU NaN behaviour

* Add test, avoid potential zero division

* Fix test item assignment

* Fix loss computation masking test

* make fixup

* Fix dtype bugs
This commit is contained in:
Matt
2022-07-01 19:11:14 +01:00
committed by GitHub
parent 485bbe79d5
commit d6cec45801
10 changed files with 278 additions and 83 deletions

View File

@@ -23,6 +23,7 @@ import tempfile
import unittest
import unittest.mock as mock
from importlib import import_module
from math import isnan
from typing import List, Tuple
from datasets import Dataset
@@ -1284,12 +1285,7 @@ class TFModelTesterMixin:
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
loss_size = tf.size(added_label)
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
# if loss is causal lm loss, labels are shift, so that one label per batch
# is cut
loss_size = loss_size - self.model_tester.batch_size
expected_loss_size = added_label.shape.as_list()[:1]
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1298,12 +1294,26 @@ class TFModelTesterMixin:
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
self.assertEqual(loss.shape.as_list(), expected_loss_size)
# Test that model correctly compute the loss when we mask some positions
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
if "labels" in prepared_for_class:
labels = prepared_for_class["labels"].numpy()
if len(labels.shape) > 1 and labels.shape[1] != 1:
labels[0] = -100
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
loss = model(model_input, **prepared_for_class)[0]
self.assertEqual(loss.shape.as_list(), expected_loss_size)
self.assertTrue(not np.any(np.isnan(loss.numpy())))
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
self.assertEqual(loss.shape.as_list(), expected_loss_size)
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1334,7 +1344,7 @@ class TFModelTesterMixin:
# Send to model
loss = model(tuple_input[:-1])[0]
self.assertEqual(loss.shape, [loss_size])
self.assertEqual(loss.shape.as_list(), expected_loss_size)
def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -1397,6 +1407,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
self.assertTrue(not isnan(val_loss1))
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
# We reinitialize the model here even though our learning rate was zero
@@ -1412,6 +1423,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(not isnan(val_loss2))
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history2.history.keys())
@@ -1437,6 +1449,7 @@ class TFModelTesterMixin:
shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
self.assertEqual(history1.history.keys(), history3.history.keys())