Rewrite TensorFlow train_step and test_step (#17057)
* Initial commit * Better label renaming * Remove breakpoint before pushing (this is your job) * Test a lot more in the Keras fit() test * make fixup * Clarify the case where we flatten y dicts into tensors * Clarify the case where we flatten y dicts into tensors * Extract label name remapping to a method
This commit is contained in:
@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
_auto_class = None
|
_auto_class = None
|
||||||
_using_dummy_loss = None
|
_using_dummy_loss = None
|
||||||
|
_label_to_output_map = None
|
||||||
|
|
||||||
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||||
# (and avoid unnecessary warnings).
|
# (and avoid unnecessary warnings).
|
||||||
@@ -907,17 +908,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
function themselves.
|
function themselves.
|
||||||
"""
|
"""
|
||||||
if loss == "passthrough":
|
if loss == "passthrough":
|
||||||
if metrics is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Passing metrics as a dict is not supported when using the internal loss! "
|
|
||||||
"Please either compile the model with a loss, or remove the metrics argument. "
|
|
||||||
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
|
|
||||||
"loss."
|
|
||||||
)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
||||||
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
||||||
"To disable this behaviour, please pass a loss argument, or explicitly pass "
|
"To disable this behaviour please pass a loss argument, or explicitly pass "
|
||||||
"`loss=None` if you do not want your model to compute a loss."
|
"`loss=None` if you do not want your model to compute a loss."
|
||||||
)
|
)
|
||||||
loss = dummy_loss
|
loss = dummy_loss
|
||||||
@@ -925,6 +919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
else:
|
else:
|
||||||
self._using_dummy_loss = False
|
self._using_dummy_loss = False
|
||||||
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
|
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
|
||||||
|
# This argument got renamed, we need to support both versions
|
||||||
if "steps_per_execution" in parent_args:
|
if "steps_per_execution" in parent_args:
|
||||||
super().compile(
|
super().compile(
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
@@ -962,18 +957,34 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
return self.hf_compute_loss(*args, **kwargs)
|
return self.hf_compute_loss(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_label_to_output_name_mapping(self):
|
||||||
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
|
if self._label_to_output_map is not None:
|
||||||
|
return self._label_to_output_map
|
||||||
|
elif "start_positions" in arg_names:
|
||||||
|
return {"start_positions": "start_logits", "end_positions": "end_logits"}
|
||||||
|
elif "sentence_order_label" in arg_names:
|
||||||
|
return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
|
||||||
|
elif "next_sentence_label" in arg_names:
|
||||||
|
return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
|
||||||
|
elif "mc_labels" in arg_names:
|
||||||
|
return {"labels": "logits", "mc_labels": "mc_logits"}
|
||||||
|
else:
|
||||||
|
return dict()
|
||||||
|
|
||||||
def train_step(self, data):
|
def train_step(self, data):
|
||||||
"""
|
"""
|
||||||
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
|
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
|
||||||
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
|
and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
|
||||||
|
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
|
||||||
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
that they are available to the model during the forward pass.
|
||||||
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
|
||||||
as keys in the input dictionary, or as normal Keras labels.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# These are the only transformations `Model.fit` applies to user-input
|
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
|
||||||
# data when a `tf.data.Dataset` is provided.
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
|
label_kwargs = find_labels(self.__class__)
|
||||||
|
label_to_output = self.get_label_to_output_name_mapping()
|
||||||
|
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||||
if not self._using_dummy_loss:
|
if not self._using_dummy_loss:
|
||||||
data = data_adapter.expand_1d(data)
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
@@ -981,8 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||||
# if those keys are not already present in the input dict
|
# if those keys are not already present in the input dict
|
||||||
if self._using_dummy_loss and y is not None:
|
if self._using_dummy_loss and y is not None:
|
||||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
|
||||||
label_kwargs = find_labels(self.__class__)
|
|
||||||
# If y is a tensor and the model only has one label-like input, map y to that input
|
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||||
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||||
if isinstance(x, tf.Tensor):
|
if isinstance(x, tf.Tensor):
|
||||||
@@ -997,6 +1007,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
for key, val in y.items():
|
for key, val in y.items():
|
||||||
if key in arg_names and key not in x:
|
if key in arg_names and key not in x:
|
||||||
x[key] = val
|
x[key] = val
|
||||||
|
elif output_to_label.get(key, None) in arg_names and key not in x:
|
||||||
|
x[output_to_label[key]] = val
|
||||||
|
if y is None:
|
||||||
|
y = {key: val for key, val in x.items() if key in label_kwargs}
|
||||||
|
if not y and not self._using_dummy_loss:
|
||||||
|
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
|
||||||
|
|
||||||
|
if isinstance(y, dict):
|
||||||
|
# Rename labels at this point to match output heads
|
||||||
|
y = {label_to_output.get(key, key): val for key, val in y.items()}
|
||||||
|
|
||||||
# Run forward pass.
|
# Run forward pass.
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
@@ -1004,14 +1024,41 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if self._using_dummy_loss:
|
if self._using_dummy_loss:
|
||||||
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||||
else:
|
else:
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
# This next block matches outputs to label keys. Tensorflow's standard method for doing this
|
||||||
|
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
|
||||||
|
if isinstance(y, dict) and len(y) == 1:
|
||||||
|
if list(y.keys())[0] in y_pred.keys():
|
||||||
|
y_pred = y_pred[list(y.keys())[0]]
|
||||||
|
elif list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred[1]
|
||||||
|
else:
|
||||||
|
y_pred = y_pred[0]
|
||||||
|
_, y = y.popitem()
|
||||||
|
elif isinstance(y, dict):
|
||||||
|
# If the labels are a dict, match keys from the output by name
|
||||||
|
y_pred = {key: val for key, val in y_pred.items() if key in y}
|
||||||
|
elif isinstance(y, tuple) or isinstance(y, list):
|
||||||
|
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
|
||||||
|
if list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred.to_tuple()[1:]
|
||||||
|
else:
|
||||||
|
y_pred = y_pred.to_tuple()
|
||||||
|
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
|
||||||
|
else:
|
||||||
|
# If the labels are a single tensor, match them to the first non-loss tensor in the output
|
||||||
|
if list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred[1]
|
||||||
|
else:
|
||||||
|
y_pred = y_pred[0]
|
||||||
|
|
||||||
|
if loss is None:
|
||||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
|
|
||||||
# Run backwards pass.
|
# Run backwards pass.
|
||||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||||
|
|
||||||
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
|
||||||
if self._using_dummy_loss:
|
|
||||||
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
|
||||||
else:
|
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
return_metrics = {}
|
return_metrics = {}
|
||||||
@@ -1021,23 +1068,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
return_metrics.update(result)
|
return_metrics.update(result)
|
||||||
else:
|
else:
|
||||||
return_metrics[metric.name] = result
|
return_metrics[metric.name] = result
|
||||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
|
||||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
|
||||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
|
||||||
del return_metrics["loss_loss"]
|
|
||||||
return return_metrics
|
return return_metrics
|
||||||
|
|
||||||
def test_step(self, data):
|
def test_step(self, data):
|
||||||
"""
|
"""
|
||||||
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
|
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
|
||||||
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
|
and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
|
||||||
|
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
|
||||||
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
that they are available to the model during the forward pass.
|
||||||
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
|
||||||
as keys in the input dictionary, or as normal Keras labels.
|
|
||||||
"""
|
"""
|
||||||
# These are the only transformations `Model.fit` applies to user-input
|
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
|
||||||
# data when a `tf.data.Dataset` is provided.
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
|
label_kwargs = find_labels(self.__class__)
|
||||||
|
label_to_output = self.get_label_to_output_name_mapping()
|
||||||
|
output_to_label = {val: key for key, val in label_to_output.items()}
|
||||||
if not self._using_dummy_loss:
|
if not self._using_dummy_loss:
|
||||||
data = data_adapter.expand_1d(data)
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
@@ -1046,7 +1090,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# if those keys are not already present in the input dict
|
# if those keys are not already present in the input dict
|
||||||
if self._using_dummy_loss and y is not None:
|
if self._using_dummy_loss and y is not None:
|
||||||
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
label_kwargs = find_labels(self.__class__)
|
|
||||||
# If y is a tensor and the model only has one label-like input, map y to that input
|
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||||
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||||
if isinstance(x, tf.Tensor):
|
if isinstance(x, tf.Tensor):
|
||||||
@@ -1061,18 +1104,54 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
for key, val in y.items():
|
for key, val in y.items():
|
||||||
if key in arg_names and key not in x:
|
if key in arg_names and key not in x:
|
||||||
x[key] = val
|
x[key] = val
|
||||||
|
elif output_to_label.get(key, None) in arg_names and key not in x:
|
||||||
|
x[output_to_label[key]] = val
|
||||||
|
if y is None:
|
||||||
|
y = {key: val for key, val in x.items() if key in label_kwargs}
|
||||||
|
if not y and not self._using_dummy_loss:
|
||||||
|
raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
|
||||||
|
|
||||||
|
if isinstance(y, dict):
|
||||||
|
# Rename labels at this point to match output heads
|
||||||
|
y = {label_to_output.get(key, key): val for key, val in y.items()}
|
||||||
|
|
||||||
# Run forward pass.
|
# Run forward pass.
|
||||||
y_pred = self(x, training=False)
|
y_pred = self(x, training=False)
|
||||||
if self._using_dummy_loss:
|
if self._using_dummy_loss:
|
||||||
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||||
else:
|
else:
|
||||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
loss = None
|
||||||
|
|
||||||
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
# This next block matches outputs to label keys. Tensorflow's standard method for doing this
|
||||||
if self._using_dummy_loss:
|
# can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
|
||||||
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
if isinstance(y, dict) and len(y) == 1:
|
||||||
|
if list(y.keys())[0] in y_pred.keys():
|
||||||
|
y_pred = y_pred[list(y.keys())[0]]
|
||||||
|
elif list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred[1]
|
||||||
else:
|
else:
|
||||||
|
y_pred = y_pred[0]
|
||||||
|
_, y = y.popitem()
|
||||||
|
elif isinstance(y, dict):
|
||||||
|
# If the labels are a dict, match keys from the output by name
|
||||||
|
y_pred = {key: val for key, val in y_pred.items() if key in y}
|
||||||
|
elif isinstance(y, tuple) or isinstance(y, list):
|
||||||
|
# If the labels are a tuple/list, match keys to the output by order, skipping the loss.
|
||||||
|
if list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred.to_tuple()[1:]
|
||||||
|
else:
|
||||||
|
y_pred = y_pred.to_tuple()
|
||||||
|
y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
|
||||||
|
else:
|
||||||
|
# If the labels are a single tensor, match them to the first non-loss tensor in the output
|
||||||
|
if list(y_pred.keys())[0] == "loss":
|
||||||
|
y_pred = y_pred[1]
|
||||||
|
else:
|
||||||
|
y_pred = y_pred[0]
|
||||||
|
|
||||||
|
if loss is None:
|
||||||
|
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
|
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
return_metrics = {}
|
return_metrics = {}
|
||||||
@@ -1082,10 +1161,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
return_metrics.update(result)
|
return_metrics.update(result)
|
||||||
else:
|
else:
|
||||||
return_metrics[metric.name] = result
|
return_metrics[metric.name] = result
|
||||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
|
||||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
|
||||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
|
||||||
del return_metrics["loss_loss"]
|
|
||||||
return return_metrics
|
return return_metrics
|
||||||
|
|
||||||
def create_model_card(
|
def create_model_card(
|
||||||
|
|||||||
@@ -1355,7 +1355,25 @@ class TFModelTesterMixin:
|
|||||||
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||||
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||||
self.assertGreater(len(inputs_minus_labels), 0)
|
self.assertGreater(len(inputs_minus_labels), 0)
|
||||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
|
accuracy_classes = [
|
||||||
|
"ForPreTraining",
|
||||||
|
"ForCausalLM",
|
||||||
|
"ForMaskedLM",
|
||||||
|
"ForQuestionAnswering",
|
||||||
|
"ForMultipleChoice",
|
||||||
|
"ForSequenceClassification",
|
||||||
|
"ForTokenClassification",
|
||||||
|
"ForNextSentencePrediction",
|
||||||
|
"LMHeadModel",
|
||||||
|
]
|
||||||
|
for accuracy_class in accuracy_classes:
|
||||||
|
if model.__class__.__name__.endswith(accuracy_class):
|
||||||
|
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
|
||||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||||
history1 = model.fit(
|
history1 = model.fit(
|
||||||
prepared_for_class,
|
prepared_for_class,
|
||||||
@@ -1365,6 +1383,7 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss1 = history1.history["val_loss"][0]
|
val_loss1 = history1.history["val_loss"][0]
|
||||||
|
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||||
history2 = model.fit(
|
history2 = model.fit(
|
||||||
inputs_minus_labels,
|
inputs_minus_labels,
|
||||||
labels,
|
labels,
|
||||||
@@ -1374,7 +1393,14 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss2 = history2.history["val_loss"][0]
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
|
accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
|
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||||
|
for key in history1.history.keys():
|
||||||
|
if not key.startswith("val_"):
|
||||||
|
self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
|
||||||
|
if metrics:
|
||||||
|
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
||||||
|
|
||||||
def test_int64_inputs(self):
|
def test_int64_inputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user