Adding new train_step logic to make things less confusing for users (#15994)
* Adding new train_step logic to make things less confusing for users * DO NOT ASK WHY WE NEED THAT SUBCLASS * Metrics now working, at least for single-output models with type annotations! * Updates and TODOs for the new train_step * Make fixup * Temporary test workaround until T5 has types * Temporary test workaround until T5 has types * I think this actually works! Needs a lot of tests though * MAke style/quality * Revert changes to T5 tests * Deleting the aforementioned unmentionable subclass * Deleting the aforementioned unmentionable subclass * Adding a Keras API test * Style fixes * Removing unneeded TODO and comments * Update test_step too * Stop trying to compute metrics with the dummy_loss, patch up test * Make style * make fixup * Docstring cleanup * make fixup * make fixup * Stop expanding 1D input tensors when using dummy loss * Adjust T5 test given the new compile() * make fixup * Skipping test for convnext * Removing old T5-specific Keras test now that we have a common one * make fixup * make fixup * Only skip convnext test on CPU * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Avoiding TF import issues * make fixup * Update compile() to support TF 2.3 * Skipping model.fit() on template classes for now * Skipping model.fit() on template class tests for now * Replace ad-hoc solution with find_labels * make fixup Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1302,6 +1302,56 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
def test_keras_fit(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
if getattr(model, "hf_compute_loss", None):
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
# Is there a better way to remove these decoder inputs?
|
||||
prepared_for_class = {
|
||||
key: val
|
||||
for key, val in prepared_for_class.items()
|
||||
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
|
||||
}
|
||||
|
||||
possible_label_cols = {
|
||||
"labels",
|
||||
"label",
|
||||
"label_ids",
|
||||
"start_positions",
|
||||
"start_position",
|
||||
"end_positions",
|
||||
"end_position",
|
||||
"next_sentence_label",
|
||||
}
|
||||
label_names = possible_label_cols.intersection(set(prepared_for_class))
|
||||
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
|
||||
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||
self.assertGreater(len(inputs_minus_labels), 0)
|
||||
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
|
||||
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||
history1 = model.fit(
|
||||
prepared_for_class,
|
||||
validation_data=prepared_for_class,
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss1 = history1.history["val_loss"][0]
|
||||
history2 = model.fit(
|
||||
inputs_minus_labels,
|
||||
labels,
|
||||
validation_data=(inputs_minus_labels, labels),
|
||||
steps_per_epoch=1,
|
||||
validation_steps=1,
|
||||
shuffle=False,
|
||||
)
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
|
||||
def test_generate_with_headmasking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user