XLA train step fixes (#17973)
* Copy inputs to train and test step before modifying them, as this breaks things * Add XLA tests, fix our loss functions to be XLA-compatible * make fixup * Update loss computation test to expect vector of per-sample losses * Patch loss for TFLED * Patch loss for TFAlbert * Add a tf_legacy_loss config flag that enables old loss functions * Stop using config.get() because it's not a dict * Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it * make fixup * Add XLA-compatible RAG loss * Fix dtype of loss mask for TFAlbert * Fix test for XLNet too because it overrides the default one * make fixup * Fix config test * No more depending on GPU NaN behaviour * Add test, avoid potential zero division * Fix test item assignment * Fix loss computation masking test * make fixup * Fix dtype bugs
This commit is contained in:
@@ -403,7 +403,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
added_label = prepared_for_class[
|
||||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
||||
]
|
||||
loss_size = tf.size(added_label)
|
||||
expected_loss_size = added_label.shape.as_list()[:1]
|
||||
|
||||
# `TFXLNetLMHeadModel` doesn't cut logits/labels
|
||||
# if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||
@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
input_ids = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(input_ids, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly compute the loss with a dict
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
loss = model(prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly compute the loss with a tuple
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# Send to model
|
||||
loss = model(tuple_input[:-1])[0]
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
|
||||
@require_tf
|
||||
|
||||
@@ -42,6 +42,7 @@ config_common_kwargs = {
|
||||
"torchscript": True,
|
||||
"torch_dtype": "float16",
|
||||
"use_bfloat16": True,
|
||||
"tf_legacy_loss": True,
|
||||
"pruned_heads": {"a": 1},
|
||||
"tie_word_embeddings": False,
|
||||
"is_decoder": True,
|
||||
|
||||
@@ -23,6 +23,7 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from importlib import import_module
|
||||
from math import isnan
|
||||
from typing import List, Tuple
|
||||
|
||||
from datasets import Dataset
|
||||
@@ -1284,12 +1285,7 @@ class TFModelTesterMixin:
|
||||
added_label = prepared_for_class[
|
||||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
||||
]
|
||||
loss_size = tf.size(added_label)
|
||||
|
||||
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||
# if loss is causal lm loss, labels are shift, so that one label per batch
|
||||
# is cut
|
||||
loss_size = loss_size - self.model_tester.batch_size
|
||||
expected_loss_size = added_label.shape.as_list()[:1]
|
||||
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
@@ -1298,12 +1294,26 @@ class TFModelTesterMixin:
|
||||
model_input = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(model_input, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly compute the loss when we mask some positions
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
possible_input_names = {"input_ids", "pixel_values", "input_features"}
|
||||
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
|
||||
model_input = prepared_for_class.pop(input_name)
|
||||
if "labels" in prepared_for_class:
|
||||
labels = prepared_for_class["labels"].numpy()
|
||||
if len(labels.shape) > 1 and labels.shape[1] != 1:
|
||||
labels[0] = -100
|
||||
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
|
||||
loss = model(model_input, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
self.assertTrue(not np.any(np.isnan(loss.numpy())))
|
||||
|
||||
# Test that model correctly compute the loss with a dict
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
loss = model(prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
# Test that model correctly compute the loss with a tuple
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
@@ -1334,7 +1344,7 @@ class TFModelTesterMixin:
|
||||
# Send to model
|
||||
loss = model(tuple_input[:-1])[0]
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||
|
||||
def test_keras_fit(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1397,6 +1407,7 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss1 = history1.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss1))
|
||||
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||
|
||||
# We reinitialize the model here even though our learning rate was zero
|
||||
@@ -1412,6 +1423,7 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss2))
|
||||
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||
@@ -1437,6 +1449,7 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss3 = history3.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss3))
|
||||
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
|
||||
self.assertEqual(history1.history.keys(), history3.history.keys())
|
||||
|
||||
@@ -18,6 +18,7 @@ import copy
|
||||
import os
|
||||
import tempfile
|
||||
from importlib import import_module
|
||||
from math import isnan
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
@@ -134,6 +135,72 @@ class TFCoreModelTesterMixin:
|
||||
outputs = run_in_graph_mode()
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@slow
|
||||
def test_xla_fit(self):
|
||||
# This is a copy of the test_keras_fit method, but we use XLA compilation instead of eager
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
if getattr(model, "hf_compute_loss", None):
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
# Is there a better way to remove these decoder inputs?
|
||||
prepared_for_class = {
|
||||
key: val
|
||||
for key, val in prepared_for_class.items()
|
||||
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
|
||||
}
|
||||
|
||||
possible_label_cols = {
|
||||
"labels",
|
||||
"label",
|
||||
"label_ids",
|
||||
"start_positions",
|
||||
"start_position",
|
||||
"end_positions",
|
||||
"end_position",
|
||||
"next_sentence_label",
|
||||
}
|
||||
label_names = possible_label_cols.intersection(set(prepared_for_class))
|
||||
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
|
||||
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||
self.assertGreater(len(inputs_minus_labels), 0)
|
||||
|
||||
# Make sure it works with XLA!
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
|
||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||
history = model.fit(
|
||||
prepared_for_class,
|
||||
validation_data=prepared_for_class,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
verbose=0,
|
||||
)
|
||||
loss = history.history["loss"][0]
|
||||
self.assertTrue(not isnan(loss))
|
||||
val_loss = history.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss))
|
||||
|
||||
# Now test it with separate labels, to make sure that path works in XLA too.
|
||||
model = model_class(config)
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
|
||||
history = model.fit(
|
||||
inputs_minus_labels,
|
||||
labels,
|
||||
validation_data=(inputs_minus_labels, labels),
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
loss = history.history["loss"][0]
|
||||
self.assertTrue(not isnan(loss))
|
||||
val_loss = history.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss))
|
||||
|
||||
@slow
|
||||
def test_saved_model_creation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user