Fix train_step, test_step and tests for CLIP (#18684)

* Fix train_step and test_step, correctly enable CLIP fit test

* Stop using get_args on older Python versions

* Don't use get_origin either

* UnionType is actually even newer, don't use that either

* Apply the same fix to test_loss_computation

* Just realized I was accidentally skipping a bunch of tests!

* Fix test_loss_computation for models without separable labels

* Fix scalar losses in test_step and train_step

* Stop committing your breakpoints

* Fix Swin loss shape

* Fix Tapas loss shape

* Shape fixes for TAPAS, DeIT, HuBERT and ViTMAE

* Add loss computation to TFMobileBertForPreTraining

* make fixup and move copied from statement

* make fixup and move copied from statement

* Correct copied from

* Add labels and next_sentence_label inputs to TFMobileBERT

* Make sure total_loss is always defined

* Update tests/test_modeling_tf_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix copied from

* Ensure CTC models get labels in tests

* Ensure CTC models get labels in tests

* Fix tests for vit_mae

* Fix tests for vit_mae

* Fix tests for vit_mae

* Reduce batch size for wav2vec2 testing because it was causing OOM

* Skip some TAPAS tests that are failing

* Skip a failing HuBERT test

* make style

* Fix mobilebertforpretraining test

* Skip Wav2Vec2 tests that use huge amounts of mem

* Skip keras_fit for Wav2Vec2 as well

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Matt
2022-09-09 20:01:02 +01:00
committed by GitHub
parent f1a6df3210
commit 660e0b97bd
13 changed files with 294 additions and 162 deletions

View File

@@ -22,9 +22,10 @@ import random
import tempfile
import unittest
import unittest.mock as mock
from dataclasses import fields
from importlib import import_module
from math import isnan
from typing import List, Tuple
from typing import List, Tuple, get_type_hints
from datasets import Dataset
@@ -124,6 +125,26 @@ def _config_zero_init(config):
return configs_no_init
def _return_type_has_loss(model):
return_type = get_type_hints(model.call)
if "return" not in return_type:
return False
return_type = return_type["return"]
if hasattr(return_type, "__args__"): # Awkward check for union because UnionType only turns up in 3.10
for type_annotation in return_type.__args__:
if inspect.isclass(type_annotation) and issubclass(type_annotation, ModelOutput):
field_names = [field.name for field in fields(type_annotation)]
if "loss" in field_names:
return True
return False
elif isinstance(return_type, tuple):
return False
elif isinstance(return_type, ModelOutput):
class_fields = fields(return_type)
return "loss" in class_fields
return False
@require_tf
class TFModelTesterMixin:
@@ -170,7 +191,7 @@ class TFModelTesterMixin:
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
*get_values(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING),
]:
] and "labels" in dict(inspect.signature(model_class.call).parameters):
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
@@ -182,6 +203,11 @@ class TFModelTesterMixin:
elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)
elif model_class.__name__.endswith("ForCTC"):
# When we have enough CTC models for an AutoClass, we should use their mapping instead of name checks
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
return inputs_dict
@@ -1335,72 +1361,74 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "hf_compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
expected_loss_size = added_label.shape.as_list()[:1]
if not getattr(model, "hf_compute_loss", None) and not _return_type_has_loss(model):
continue
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label_names = sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)
if not added_label_names:
continue # This test is only for models with easily-separable labels
added_label = prepared_for_class[added_label_names[0]]
expected_loss_size = added_label.shape.as_list()[:1]
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features", "input_values"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss when we mask some positions
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
if "labels" in prepared_for_class:
labels = prepared_for_class["labels"].numpy()
if len(labels.shape) > 1 and labels.shape[1] != 1:
labels[0] = -100
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
self.assertTrue(not np.any(np.isnan(loss.numpy())))
# Test that model correctly compute the loss when we mask some positions
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features", "input_values"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
if "labels" in prepared_for_class:
labels = prepared_for_class["labels"].numpy()
if len(labels.shape) > 1 and labels.shape[1] != 1:
labels[0] = -100
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
self.assertTrue(not np.any(np.isnan(loss.numpy())))
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: input_name}
for label_key in label_keys:
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: input_name}
for label_key in label_keys:
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []
for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)
for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)
for index, value in sorted_tuple_index_mapping:
list_input[index] = prepared_for_class[value]
for index, value in sorted_tuple_index_mapping:
list_input[index] = prepared_for_class[value]
tuple_input = tuple(list_input)
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input[:-1])[0]
# Send to model
loss = model(tuple_input[:-1])[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
@@ -1409,111 +1437,118 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "hf_compute_loss", None):
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Is there a better way to remove these decoder inputs?
prepared_for_class = {
key: val
for key, val in prepared_for_class.items()
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
}
if not getattr(model, "hf_compute_loss", False) and not _return_type_has_loss(model):
continue
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Is there a better way to remove these decoder inputs?
# We also remove "return_loss" as this is covered by the train_step when using fit()
prepared_for_class = {
key: val
for key, val in prepared_for_class.items()
if key
not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids", "return_loss")
}
possible_label_cols = {
"labels",
"label",
"label_ids",
"start_positions",
"start_position",
"end_positions",
"end_position",
"next_sentence_label",
}
label_names = possible_label_cols.intersection(set(prepared_for_class))
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
accuracy_classes = [
"ForPreTraining",
"ForCausalLM",
"ForMaskedLM",
"ForQuestionAnswering",
"ForMultipleChoice",
"ForSequenceClassification",
"ForTokenClassification",
"ForNextSentencePrediction",
"LMHeadModel",
]
for accuracy_class in accuracy_classes:
if model.__class__.__name__.endswith(accuracy_class):
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
break
else:
metrics = []
accuracy_classes = [
"ForPreTraining",
"ForCausalLM",
"ForMaskedLM",
"ForQuestionAnswering",
"ForMultipleChoice",
"ForSequenceClassification",
"ForTokenClassification",
"ForNextSentencePrediction",
"LMHeadModel",
]
for accuracy_class in accuracy_classes:
if model.__class__.__name__.endswith(accuracy_class):
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
break
else:
metrics = []
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
model(model.dummy_inputs) # Build the model so we can get some constant weights
model_weights = model.get_weights()
# Run eagerly to save some expensive compilation times
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
self.assertTrue(not isnan(val_loss1))
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
# Run eagerly to save some expensive compilation times
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
self.assertTrue(not isnan(val_loss1))
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
# We reinitialize the model here even though our learning rate was zero
# because BatchNorm updates weights by means other than gradient descent.
model.set_weights(model_weights)
possible_label_cols = {
"labels",
"label",
"label_ids",
"start_positions",
"start_position",
"end_positions",
"end_position",
"next_sentence_label",
}
label_names = possible_label_cols.intersection(set(prepared_for_class))
if len(label_names) == 0:
# The next tests only make sense for models with separate inputs and labels, and do not make
# sense for models that don't clearly distinguish between the two (e.g. CLIP)
return
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
history2 = model.fit(
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(not isnan(val_loss2))
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss2)
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
if not key.startswith("val_"):
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
# We reinitialize the model here even though our learning rate was zero
# because BatchNorm updates weights by means other than gradient descent.
model.set_weights(model_weights)
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
# Pass in all samples as a batch to match other `fit` calls
dataset = dataset.batch(len(dataset))
history2 = model.fit(
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(not isnan(val_loss2))
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss2)
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
if not key.startswith("val_"):
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
# Make sure fit works with tf.data.Dataset and results are consistent
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
# Pass in all samples as a batch to match other `fit` calls
dataset = dataset.batch(len(dataset))
history3 = model.fit(
dataset,
validation_data=dataset,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss3)
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)
history3 = model.fit(
dataset,
validation_data=dataset,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.check_keras_fit_results(val_loss1, val_loss3)
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
def test_int64_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()