Adding new train_step logic to make things less confusing for users (#15994)
* Adding new train_step logic to make things less confusing for users
* DO NOT ASK WHY WE NEED THAT SUBCLASS
* Metrics now working, at least for single-output models with type annotations!
* Updates and TODOs for the new train_step
* Make fixup
* Temporary test workaround until T5 has types
* Temporary test workaround until T5 has types
* I think this actually works! Needs a lot of tests though
* Make style/quality
* Revert changes to T5 tests
* Deleting the aforementioned unmentionable subclass
* Deleting the aforementioned unmentionable subclass
* Adding a Keras API test
* Style fixes
* Removing unneeded TODO and comments
* Update test_step too
* Stop trying to compute metrics with the dummy_loss, patch up test
* Make style
* make fixup
* Docstring cleanup
* make fixup
* make fixup
* Stop expanding 1D input tensors when using dummy loss
* Adjust T5 test given the new compile()
* make fixup
* Skipping test for convnext
* Removing old T5-specific Keras test now that we have a common one
* make fixup
* make fixup
* Only skip convnext test on CPU
* Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Avoiding TF import issues
* make fixup
* Update compile() to support TF 2.3
* Skipping model.fit() on template classes for now
* Skipping model.fit() on template class tests for now
* Replace ad-hoc solution with find_labels
* make fixup Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -143,6 +143,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(
    not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
    reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
    """Run the shared Keras ``fit()`` test only when TensorFlow can see a GPU.

    The skip condition exists because, per the skip reason, TF (<=2.8) does not
    support backprop for grouped convolutions on CPU.
    """
    # Delegate to the inherited implementation. A bare `pass` here would make the
    # test report success on GPU machines without ever calling `fit()`.
    super().test_keras_fit()
|
||||
|
||||
@unittest.skip(reason="ConvNext does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@@ -804,33 +804,3 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
|
||||
self.assertEqual(translation, expected_translation)
|
||||
|
||||
def test_finetune_keras_trainer(self):
    """Ensure that the model can be fine-tuned via the keras API and
    that metrics work as expected.
    """

    # This metric expects to be called with the logits output, so it must not assume
    # normalised probabilities. The previous implementation computed a cross-entropy
    # (without `from_logits=True`) under the name "accuracy"; an accuracy function
    # matches both the metric name and the stated intent.
    def _accuracy(y_true, y_pred):
        return tf.keras.metrics.sparse_categorical_accuracy(y_true[:, 0], y_pred[:, 0])

    # measure the accuracy of the first token
    class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
        def __init__(self, name="accuracy", **kwargs):
            super().__init__(_accuracy, name=name, **kwargs)

    model = self.model
    model.compile("adam", metrics=FirstTokenAccuracy())
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    # (source text, target text) pairs
    examples = [
        ("sentiment: Everything is awesome!", "positive"),
        ("sentiment: Tensorflow datasets are hard to use", "negative"),
    ]

    inputs = dict(tokenizer([text for text, _ in examples], padding=True, return_tensors="tf"))
    inputs["labels"] = tokenizer([target for _, target in examples], return_tensors="tf").input_ids

    model.fit(inputs)
    m = model.evaluate(inputs)
    # `evaluate` is expected to return two values here (presumably the loss and the single metric).
    self.assertEqual(len(m), 2)
|
||||
|
||||
@@ -1302,6 +1302,56 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
def test_keras_fit(self):
    """Check that Keras ``fit()`` runs and gives matching validation losses wherever the labels are passed.

    Every model class exposing a truthy ``hf_compute_loss`` attribute is fitted twice with a zero
    learning rate: once with the labels left inside the input dict, and once with the labels
    supplied separately from the inputs. The first validation loss of each run must agree.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Is there a better way to remove these decoder inputs?
    dropped_keys = ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
    possible_label_cols = {
        "labels",
        "label",
        "label_ids",
        "start_positions",
        "start_position",
        "end_positions",
        "end_position",
        "next_sentence_label",
    }
    # Both fit() calls run a single training step and a single validation step, unshuffled.
    fit_kwargs = dict(steps_per_epoch=1, validation_steps=1, shuffle=False)

    for model_class in self.all_model_classes:
        model = model_class(config)
        if not getattr(model, "hf_compute_loss", None):
            # Classes without `hf_compute_loss` are not exercised by this test.
            continue

        full_batch = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
        full_batch = {name: tensor for name, tensor in full_batch.items() if name not in dropped_keys}

        label_names = possible_label_cols.intersection(set(full_batch))
        self.assertGreater(len(label_names), 0, msg="No matching label names found!")

        # Split the batch into label entries and everything else, preserving key order.
        labels, inputs_minus_labels = {}, {}
        for name, tensor in full_batch.items():
            bucket = labels if name in label_names else inputs_minus_labels
            bucket[name] = tensor
        self.assertGreater(len(inputs_minus_labels), 0)

        model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)

        # Make sure the model fits without crashing regardless of where we pass the labels
        history_labels_in_inputs = model.fit(full_batch, validation_data=full_batch, **fit_kwargs)
        val_loss1 = history_labels_in_inputs.history["val_loss"][0]

        history_labels_separate = model.fit(
            inputs_minus_labels,
            labels,
            validation_data=(inputs_minus_labels, labels),
            **fit_kwargs,
        )
        val_loss2 = history_labels_separate.history["val_loss"][0]

        self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
|
||||
def test_generate_with_headmasking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user