TF Model train and eval step metrics for seq2seq models. (#14009)
* TF Model train and eval step metrics for seq2seq models. When using a model with a seq2seq output compute metrics against logits. * Removing vestigial code Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from .file_utils import (
|
|||||||
is_remote_url,
|
is_remote_url,
|
||||||
)
|
)
|
||||||
from .generation_tf_utils import TFGenerationMixin
|
from .generation_tf_utils import TFGenerationMixin
|
||||||
|
from .modeling_tf_outputs import TFSeq2SeqLMOutput
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -787,6 +788,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
# Run backwards pass.
|
# Run backwards pass.
|
||||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||||
|
# When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
|
||||||
|
# should be done only with the relevant ModelOutput param that is
|
||||||
|
# considered by the loss.
|
||||||
|
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
|
||||||
|
y_pred = y_pred["logits"]
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
return_metrics = {}
|
return_metrics = {}
|
||||||
@@ -813,17 +819,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if y is None and "labels" in x:
|
if y is None and "labels" in x:
|
||||||
y = x["labels"] # Stops confusion with metric computations
|
y = x["labels"] # Stops confusion with metric computations
|
||||||
y_pred = self(x, training=False)
|
y_pred = self(x, training=False)
|
||||||
if not self.loss:
|
|
||||||
self.loss_tracker.update_state(y_pred.loss)
|
|
||||||
return_metrics = {"loss": self.loss_tracker.result()}
|
|
||||||
else:
|
|
||||||
# Run anyway to update state
|
|
||||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
|
||||||
return_metrics = {}
|
|
||||||
# Updates stateful loss metrics.
|
|
||||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
|
# Updates stateful loss metrics.
|
||||||
|
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
|
||||||
|
y_pred = y_pred["logits"]
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
|
return_metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
result = metric.result()
|
result = metric.result()
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
|
|||||||
@@ -666,3 +666,33 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
|
|||||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
self.assertEqual(translation, expected_translation)
|
self.assertEqual(translation, expected_translation)
|
||||||
|
|
||||||
|
def test_finetune_keras_trainer(self):
|
||||||
|
"""Ensure that the model can be fine-tuned via the keras API and
|
||||||
|
that metrics work as expected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This metric expects to be called with the logits output
|
||||||
|
def _accuracy(y_true, y_pred):
|
||||||
|
return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])
|
||||||
|
|
||||||
|
# measure the accuracy of the first token
|
||||||
|
class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
|
||||||
|
def __init__(self, name="accuracy", **kwargs):
|
||||||
|
super().__init__(_accuracy, name=name, **kwargs)
|
||||||
|
|
||||||
|
model = self.model
|
||||||
|
model.compile("adam", metrics=FirstTokenAccuracy())
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
("sentiment: Everything is awesome!", "positive"),
|
||||||
|
("sentiment: Tensorflow datasets are hard to use", "negative"),
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
|
||||||
|
inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids
|
||||||
|
|
||||||
|
model.fit(inputs)
|
||||||
|
m = model.evaluate(inputs)
|
||||||
|
self.assertEqual(len(m), 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user