Rewrite TensorFlow train_step and test_step (#17057)
* Initial commit * Better label renaming * Remove breakpoint before pushing (this is your job) * Test a lot more in the Keras fit() test * make fixup * Clarify the case where we flatten y dicts into tensors * Clarify the case where we flatten y dicts into tensors * Extract label name remapping to a method
This commit is contained in:
@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
main_input_name = "input_ids"
|
||||
_auto_class = None
|
||||
_using_dummy_loss = None
|
||||
_label_to_output_map = None
|
||||
|
||||
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||
# (and avoid unnecessary warnings).
|
||||
@@ -907,17 +908,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
function themselves.
|
||||
"""
|
||||
if loss == "passthrough":
|
||||
if metrics is not None:
|
||||
raise ValueError(
|
||||
"Passing metrics as a dict is not supported when using the internal loss! "
|
||||
"Please either compile the model with a loss, or remove the metrics argument. "
|
||||
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
|
||||
"loss."
|
||||
)
|
||||
logger.warning(
|
||||
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
||||
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
||||
"To disable this behaviour, please pass a loss argument, or explicitly pass "
|
||||
"To disable this behaviour please pass a loss argument, or explicitly pass "
|
||||
"`loss=None` if you do not want your model to compute a loss."
|
||||
)
|
||||
loss = dummy_loss
|
||||
@@ -925,6 +919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
else:
|
||||
self._using_dummy_loss = False
|
||||
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
|
||||
# This argument got renamed, we need to support both versions
|
||||
if "steps_per_execution" in parent_args:
|
||||
super().compile(
|
||||
optimizer=optimizer,
|
||||
@@ -962,18 +957,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
)
|
||||
return self.hf_compute_loss(*args, **kwargs)
|
||||
|
||||
def get_label_to_output_name_mapping(self):
|
||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||
if self._label_to_output_map is not None:
|
||||
return self._label_to_output_map
|
||||
elif "start_positions" in arg_names:
|
||||
return {"start_positions": "start_logits", "end_positions": "end_logits"}
|
||||
elif "sentence_order_label" in arg_names:
|
||||
return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
|
||||
elif "next_sentence_label" in arg_names:
|
||||
return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
|
||||
elif "mc_labels" in arg_names:
|
||||
return {"labels": "logits", "mc_labels": "mc_logits"}
|
||||
else:
|
||||
return dict()
|
||||
|
||||
def train_step(self, data):
|
||||
"""
|
||||
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
|
||||
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
|
||||
|
||||
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
||||
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
||||
as keys in the input dictionary, or as normal Keras labels.
|
||||
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
|
||||
and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
|
||||
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
|
||||
that they are available to the model during the forward pass.
|
||||
"""
|
||||
|
||||
# These are the only transformations `Model.fit` applies to user-input
|
||||
# data when a `tf.data.Dataset` is provided.
|
||||
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
|
||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||
label_kwargs = find_labels(self.__class__)
|
||||
label_to_output = self.get_label_to_output_name_mapping()
|
||||
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||
if not self._using_dummy_loss:
|
||||
data = data_adapter.expand_1d(data)
|
||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||
@@ -981,8 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||
# if those keys are not already present in the input dict
|
||||
if self._using_dummy_loss and y is not None:
|
||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||
label_kwargs = find_labels(self.__class__)
|
||||
|
||||
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||
if isinstance(x, tf.Tensor):
|
||||
@@ -997,6 +1007,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
for key, val in y.items():
|
||||
if key in arg_names and key not in x:
|
||||
x[key] = val
|
||||
elif output_to_label.get(key, None) in arg_names and key not in x:
|
||||
x[output_to_label[key]] = val
|
||||
if y is None:
|
||||
y = {key: val for key, val in x.items() if key in label_kwargs}
|
||||
if not y and not self._using_dummy_loss:
|
||||
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
|
||||
|
||||
if isinstance(y, dict):
|
||||
# Rename labels at this point to match output heads
|
||||
y = {label_to_output.get(key, key): val for key, val in y.items()}
|
||||
|
||||
# Run forward pass.
|
||||
with tf.GradientTape() as tape:
|
||||
@@ -1004,15 +1024,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
if self._using_dummy_loss:
|
||||
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||
else:
|
||||
loss = None
|
||||
|
||||
# This next block matches outputs to label keys. Tensorflow's standard method for doing this
|
||||
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
|
||||
if isinstance(y, dict) and len(y) == 1:
|
||||
if list(y.keys())[0] in y_pred.keys():
|
||||
y_pred = y_pred[list(y.keys())[0]]
|
||||
elif list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred[1]
|
||||
else:
|
||||
y_pred = y_pred[0]
|
||||
_, y = y.popitem()
|
||||
elif isinstance(y, dict):
|
||||
# If the labels are a dict, match keys from the output by name
|
||||
y_pred = {key: val for key, val in y_pred.items() if key in y}
|
||||
elif isinstance(y, tuple) or isinstance(y, list):
|
||||
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
|
||||
if list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred.to_tuple()[1:]
|
||||
else:
|
||||
y_pred = y_pred.to_tuple()
|
||||
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
|
||||
else:
|
||||
# If the labels are a single tensor, match them to the first non-loss tensor in the output
|
||||
if list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred[1]
|
||||
else:
|
||||
y_pred = y_pred[0]
|
||||
|
||||
if loss is None:
|
||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
|
||||
# Run backwards pass.
|
||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||
|
||||
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
||||
if self._using_dummy_loss:
|
||||
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
||||
else:
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
# Collect metrics to return
|
||||
return_metrics = {}
|
||||
for metric in self.metrics:
|
||||
@@ -1021,23 +1068,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
return_metrics.update(result)
|
||||
else:
|
||||
return_metrics[metric.name] = result
|
||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
||||
del return_metrics["loss_loss"]
|
||||
return return_metrics
|
||||
|
||||
def test_step(self, data):
|
||||
"""
|
||||
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
|
||||
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
|
||||
|
||||
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
||||
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
||||
as keys in the input dictionary, or as normal Keras labels.
|
||||
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
|
||||
and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
|
||||
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
|
||||
that they are available to the model during the forward pass.
|
||||
"""
|
||||
# These are the only transformations `Model.fit` applies to user-input
|
||||
# data when a `tf.data.Dataset` is provided.
|
||||
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
|
||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||
label_kwargs = find_labels(self.__class__)
|
||||
label_to_output = self.get_label_to_output_name_mapping()
|
||||
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||
if not self._using_dummy_loss:
|
||||
data = data_adapter.expand_1d(data)
|
||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||
@@ -1046,7 +1090,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# if those keys are not already present in the input dict
|
||||
if self._using_dummy_loss and y is not None:
|
||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||
label_kwargs = find_labels(self.__class__)
|
||||
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||
if isinstance(x, tf.Tensor):
|
||||
@@ -1061,19 +1104,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
for key, val in y.items():
|
||||
if key in arg_names and key not in x:
|
||||
x[key] = val
|
||||
elif output_to_label.get(key, None) in arg_names and key not in x:
|
||||
x[output_to_label[key]] = val
|
||||
if y is None:
|
||||
y = {key: val for key, val in x.items() if key in label_kwargs}
|
||||
if not y and not self._using_dummy_loss:
|
||||
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
|
||||
|
||||
if isinstance(y, dict):
|
||||
# Rename labels at this point to match output heads
|
||||
y = {label_to_output.get(key, key): val for key, val in y.items()}
|
||||
|
||||
# Run forward pass.
|
||||
y_pred = self(x, training=False)
|
||||
if self._using_dummy_loss:
|
||||
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||
else:
|
||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
loss = None
|
||||
|
||||
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
||||
if self._using_dummy_loss:
|
||||
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
||||
# This next block matches outputs to label keys. Tensorflow's standard method for doing this
|
||||
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
|
||||
if isinstance(y, dict) and len(y) == 1:
|
||||
if list(y.keys())[0] in y_pred.keys():
|
||||
y_pred = y_pred[list(y.keys())[0]]
|
||||
elif list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred[1]
|
||||
else:
|
||||
y_pred = y_pred[0]
|
||||
_, y = y.popitem()
|
||||
elif isinstance(y, dict):
|
||||
# If the labels are a dict, match keys from the output by name
|
||||
y_pred = {key: val for key, val in y_pred.items() if key in y}
|
||||
elif isinstance(y, tuple) or isinstance(y, list):
|
||||
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
|
||||
if list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred.to_tuple()[1:]
|
||||
else:
|
||||
y_pred = y_pred.to_tuple()
|
||||
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
|
||||
else:
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
# If the labels are a single tensor, match them to the first non-loss tensor in the output
|
||||
if list(y_pred.keys())[0] == "loss":
|
||||
y_pred = y_pred[1]
|
||||
else:
|
||||
y_pred = y_pred[0]
|
||||
|
||||
if loss is None:
|
||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
# Collect metrics to return
|
||||
return_metrics = {}
|
||||
for metric in self.metrics:
|
||||
@@ -1082,10 +1161,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
return_metrics.update(result)
|
||||
else:
|
||||
return_metrics[metric.name] = result
|
||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
||||
del return_metrics["loss_loss"]
|
||||
return return_metrics
|
||||
|
||||
def create_model_card(
|
||||
|
||||
@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
|
||||
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||
self.assertGreater(len(inputs_minus_labels), 0)
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
|
||||
accuracy_classes = [
|
||||
"ForPreTraining",
|
||||
"ForCausalLM",
|
||||
"ForMaskedLM",
|
||||
"ForQuestionAnswering",
|
||||
"ForMultipleChoice",
|
||||
"ForSequenceClassification",
|
||||
"ForTokenClassification",
|
||||
"ForNextSentencePrediction",
|
||||
"LMHeadModel",
|
||||
]
|
||||
for accuracy_class in accuracy_classes:
|
||||
if model.__class__.__name__.endswith(accuracy_class):
|
||||
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
|
||||
break
|
||||
else:
|
||||
metrics = []
|
||||
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
|
||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||
history1 = model.fit(
|
||||
prepared_for_class,
|
||||
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss1 = history1.history["val_loss"][0]
|
||||
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||
history2 = model.fit(
|
||||
inputs_minus_labels,
|
||||
labels,
|
||||
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||
for key in history1.history.keys():
|
||||
if not key.startswith("val_"):
|
||||
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
||||
|
||||
def test_int64_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user