Adding new train_step logic to make things less confusing for users (#15994)
* Adding new train_step logic to make things less confusing for users * DO NOT ASK WHY WE NEED THAT SUBCLASS * Metrics now working, at least for single-output models with type annotations! * Updates and TODOs for the new train_step * Make fixup * Temporary test workaround until T5 has types * Temporary test workaround until T5 has types * I think this actually works! Needs a lot of tests though * MAke style/quality * Revert changes to T5 tests * Deleting the aforementioned unmentionable subclass * Deleting the aforementioned unmentionable subclass * Adding a Keras API test * Style fixes * Removing unneeded TODO and comments * Update test_step too * Stop trying to compute metrics with the dummy_loss, patch up test * Make style * make fixup * Docstring cleanup * make fixup * make fixup * Stop expanding 1D input tensors when using dummy loss * Adjust T5 test given the new compile() * make fixup * Skipping test for convnext * Removing old T5-specific Keras test now that we have a common one * make fixup * make fixup * Only skip convnext test on CPU * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Avoiding TF import issues * make fixup * Update compile() to support TF 2.3 * Skipping model.fit() on template classes for now * Skipping model.fit() on template class tests for now * Replace ad-hoc solution with find_labels * make fixup Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -38,7 +38,6 @@ from .activations_tf import get_tf_activation
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation_tf_utils import TFGenerationMixin
|
from .generation_tf_utils import TFGenerationMixin
|
||||||
from .modeling_tf_outputs import TFSeq2SeqLMOutput
|
|
||||||
from .tf_utils import shape_list
|
from .tf_utils import shape_list
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -53,6 +52,7 @@ from .utils import (
|
|||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
find_labels,
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
@@ -715,6 +715,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
_auto_class = None
|
_auto_class = None
|
||||||
|
_using_dummy_loss = None
|
||||||
|
|
||||||
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
# a list of re pattern of tensor names to ignore from the model when loading the model weights
|
||||||
# (and avoid unnecessary warnings).
|
# (and avoid unnecessary warnings).
|
||||||
@@ -899,24 +900,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
function themselves.
|
function themselves.
|
||||||
"""
|
"""
|
||||||
if loss == "passthrough":
|
if loss == "passthrough":
|
||||||
|
if metrics is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Passing metrics as a dict is not supported when using the internal loss! "
|
||||||
|
"Please either compile the model with a loss, or remove the metrics argument. "
|
||||||
|
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
|
||||||
|
"loss."
|
||||||
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
||||||
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
||||||
"Please ensure your labels are passed as keys in the input dict so that they are "
|
"To disable this behaviour, please pass a loss argument, or explicitly pass "
|
||||||
"accessible to the model during the forward pass. To disable this behaviour, please pass a "
|
"`loss=None` if you do not want your model to compute a loss."
|
||||||
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
|
)
|
||||||
|
loss = dummy_loss
|
||||||
|
self._using_dummy_loss = True
|
||||||
|
else:
|
||||||
|
self._using_dummy_loss = False
|
||||||
|
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
|
||||||
|
if "steps_per_execution" in parent_args:
|
||||||
|
super().compile(
|
||||||
|
optimizer=optimizer,
|
||||||
|
loss=loss,
|
||||||
|
metrics=metrics,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
weighted_metrics=weighted_metrics,
|
||||||
|
run_eagerly=run_eagerly,
|
||||||
|
steps_per_execution=steps_per_execution,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
super().compile(
|
||||||
|
optimizer=optimizer,
|
||||||
|
loss=loss,
|
||||||
|
metrics=metrics,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
weighted_metrics=weighted_metrics,
|
||||||
|
run_eagerly=run_eagerly,
|
||||||
|
experimental_steps_per_execution=steps_per_execution,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
loss = {"loss": dummy_loss}
|
|
||||||
super().compile(
|
|
||||||
optimizer=optimizer,
|
|
||||||
loss=loss,
|
|
||||||
metrics=metrics,
|
|
||||||
loss_weights=loss_weights,
|
|
||||||
weighted_metrics=weighted_metrics,
|
|
||||||
run_eagerly=run_eagerly,
|
|
||||||
steps_per_execution=steps_per_execution,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_loss(self, *args, **kwargs):
|
def compute_loss(self, *args, **kwargs):
|
||||||
if hasattr(tf.keras.Model, "compute_loss"):
|
if hasattr(tf.keras.Model, "compute_loss"):
|
||||||
@@ -935,40 +958,54 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
def train_step(self, data):
|
def train_step(self, data):
|
||||||
"""
|
"""
|
||||||
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
|
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
|
||||||
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`. In
|
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
|
||||||
this case, it expects the same `data` as the original function (i.e. `(inputs, labels)`).
|
|
||||||
|
|
||||||
However, when the model is compiled without specifying the loss AND the expected label columns are passed as
|
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
||||||
part of the input dictionary, the loss is computed internally (inside the model class) and is used in the
|
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
||||||
backwards pass. In this case, `data` is a singleton tuple containing `(inputs,)`.
|
as keys in the input dictionary, or as normal Keras labels.
|
||||||
|
|
||||||
This is possible under the aforementioned circumstances because our overriden compile function can set an
|
|
||||||
additional loss function that reduces a `loss` output, and the model will output a `loss` component (notice the
|
|
||||||
name matching) containing the loss that was used to train the pre-trained model.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# These are the only transformations `Model.fit` applies to user-input
|
# These are the only transformations `Model.fit` applies to user-input
|
||||||
# data when a `tf.data.Dataset` is provided.
|
# data when a `tf.data.Dataset` is provided.
|
||||||
data = data_adapter.expand_1d(data)
|
if not self._using_dummy_loss:
|
||||||
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
# These next two lines differ from the base method - they avoid issues when the labels are in
|
|
||||||
# the input dict (and loss is computed internally)
|
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||||
if y is None and "labels" in x:
|
# if those keys are not already present in the input dict
|
||||||
y = x["labels"] # Stops confusion with metric computations
|
if self._using_dummy_loss and y is not None:
|
||||||
elif y is None and "input_ids" in x:
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
# Just make any kind of dummy array to make loss work
|
label_kwargs = find_labels(self.__class__)
|
||||||
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
|
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||||
|
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||||
|
if isinstance(x, tf.Tensor):
|
||||||
|
x = {arg_names[0]: x}
|
||||||
|
label_kwarg = next(iter(label_kwargs))
|
||||||
|
if label_kwarg not in x:
|
||||||
|
x[label_kwarg] = y
|
||||||
|
# Otherwise, copy keys from y to x as long as they weren't already present in x
|
||||||
|
elif isinstance(y, dict):
|
||||||
|
if isinstance(x, tf.Tensor):
|
||||||
|
x = {arg_names[0]: x}
|
||||||
|
for key, val in y.items():
|
||||||
|
if key in arg_names and key not in x:
|
||||||
|
x[key] = val
|
||||||
|
|
||||||
# Run forward pass.
|
# Run forward pass.
|
||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
y_pred = self(x, training=True)
|
y_pred = self(x, training=True)
|
||||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
if self._using_dummy_loss:
|
||||||
|
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||||
|
else:
|
||||||
|
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
# Run backwards pass.
|
# Run backwards pass.
|
||||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||||
# When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
|
|
||||||
# should be done only with the relevant ModelOutput param that is
|
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
||||||
# considered by the loss.
|
if self._using_dummy_loss:
|
||||||
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
|
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
||||||
y_pred = y_pred["logits"]
|
else:
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
return_metrics = {}
|
return_metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
@@ -985,23 +1022,51 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
|
|
||||||
def test_step(self, data):
|
def test_step(self, data):
|
||||||
"""
|
"""
|
||||||
A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
|
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
|
||||||
|
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
|
||||||
|
|
||||||
|
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
|
||||||
|
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
|
||||||
|
as keys in the input dictionary, or as normal Keras labels.
|
||||||
"""
|
"""
|
||||||
data = data_adapter.expand_1d(data)
|
# These are the only transformations `Model.fit` applies to user-input
|
||||||
|
# data when a `tf.data.Dataset` is provided.
|
||||||
|
if not self._using_dummy_loss:
|
||||||
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
# These next two lines differ from the base method - they avoid issues when the labels are in
|
|
||||||
# the input dict (and loss is computed internally)
|
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||||
if y is None and "labels" in x:
|
# if those keys are not already present in the input dict
|
||||||
y = x["labels"] # Stops confusion with metric computations
|
if self._using_dummy_loss and y is not None:
|
||||||
elif y is None and "input_ids" in x:
|
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
# Just make any kind of dummy array to make loss work
|
label_kwargs = find_labels(self.__class__)
|
||||||
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
|
# If y is a tensor and the model only has one label-like input, map y to that input
|
||||||
|
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
|
||||||
|
if isinstance(x, tf.Tensor):
|
||||||
|
x = {arg_names[0]: x}
|
||||||
|
label_kwarg = next(iter(label_kwargs))
|
||||||
|
if label_kwarg not in x:
|
||||||
|
x[label_kwarg] = y
|
||||||
|
# Otherwise, copy keys from y to x as long as they weren't already present in x
|
||||||
|
elif isinstance(y, dict):
|
||||||
|
if isinstance(x, tf.Tensor):
|
||||||
|
x = {arg_names[0]: x}
|
||||||
|
for key, val in y.items():
|
||||||
|
if key in arg_names and key not in x:
|
||||||
|
x[key] = val
|
||||||
|
|
||||||
|
# Run forward pass.
|
||||||
y_pred = self(x, training=False)
|
y_pred = self(x, training=False)
|
||||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
if self._using_dummy_loss:
|
||||||
# Updates stateful loss metrics.
|
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
|
||||||
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
|
else:
|
||||||
y_pred = y_pred["logits"]
|
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
|
||||||
|
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
|
||||||
|
if self._using_dummy_loss:
|
||||||
|
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
|
||||||
|
else:
|
||||||
|
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||||
# Collect metrics to return
|
# Collect metrics to return
|
||||||
return_metrics = {}
|
return_metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
|
|||||||
@@ -259,6 +259,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
|
|||||||
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
|
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_causal_lm_model_past(
|
def create_and_check_causal_lm_model_past(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -597,6 +598,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Template classes interact badly with this test.")
|
||||||
|
def test_keras_fit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_causal_lm_base_model(self):
|
def test_causal_lm_base_model(self):
|
||||||
"""Test the base model of the causal LM model
|
"""Test the base model of the causal LM model
|
||||||
|
|
||||||
@@ -947,6 +952,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
|
|||||||
models_equal = False
|
models_equal = False
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Template classes interact badly with this test.")
|
||||||
|
def test_keras_fit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -143,6 +143,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
|
||||||
|
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
|
||||||
|
)
|
||||||
|
def test_keras_fit(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="ConvNext does not support input and output embeddings")
|
@unittest.skip(reason="ConvNext does not support input and output embeddings")
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -804,33 +804,3 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
|
|||||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
self.assertEqual(translation, expected_translation)
|
self.assertEqual(translation, expected_translation)
|
||||||
|
|
||||||
def test_finetune_keras_trainer(self):
|
|
||||||
"""Ensure that the model can be fine-tuned via the keras API and
|
|
||||||
that metrics work as expected.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# This metric expects to be called with the logits output
|
|
||||||
def _accuracy(y_true, y_pred):
|
|
||||||
return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])
|
|
||||||
|
|
||||||
# measure the accuracy of the first token
|
|
||||||
class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
|
|
||||||
def __init__(self, name="accuracy", **kwargs):
|
|
||||||
super().__init__(_accuracy, name=name, **kwargs)
|
|
||||||
|
|
||||||
model = self.model
|
|
||||||
model.compile("adam", metrics=FirstTokenAccuracy())
|
|
||||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
|
||||||
|
|
||||||
examples = [
|
|
||||||
("sentiment: Everything is awesome!", "positive"),
|
|
||||||
("sentiment: Tensorflow datasets are hard to use", "negative"),
|
|
||||||
]
|
|
||||||
|
|
||||||
inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
|
|
||||||
inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids
|
|
||||||
|
|
||||||
model.fit(inputs)
|
|
||||||
m = model.evaluate(inputs)
|
|
||||||
self.assertEqual(len(m), 2)
|
|
||||||
|
|||||||
@@ -1302,6 +1302,56 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
|
def test_keras_fit(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
if getattr(model, "hf_compute_loss", None):
|
||||||
|
# Test that model correctly compute the loss with kwargs
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
# Is there a better way to remove these decoder inputs?
|
||||||
|
prepared_for_class = {
|
||||||
|
key: val
|
||||||
|
for key, val in prepared_for_class.items()
|
||||||
|
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
|
||||||
|
}
|
||||||
|
|
||||||
|
possible_label_cols = {
|
||||||
|
"labels",
|
||||||
|
"label",
|
||||||
|
"label_ids",
|
||||||
|
"start_positions",
|
||||||
|
"start_position",
|
||||||
|
"end_positions",
|
||||||
|
"end_position",
|
||||||
|
"next_sentence_label",
|
||||||
|
}
|
||||||
|
label_names = possible_label_cols.intersection(set(prepared_for_class))
|
||||||
|
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
|
||||||
|
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
|
||||||
|
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
|
||||||
|
self.assertGreater(len(inputs_minus_labels), 0)
|
||||||
|
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
|
||||||
|
# Make sure the model fits without crashing regardless of where we pass the labels
|
||||||
|
history1 = model.fit(
|
||||||
|
prepared_for_class,
|
||||||
|
validation_data=prepared_for_class,
|
||||||
|
steps_per_epoch=1,
|
||||||
|
validation_steps=1,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
val_loss1 = history1.history["val_loss"][0]
|
||||||
|
history2 = model.fit(
|
||||||
|
inputs_minus_labels,
|
||||||
|
labels,
|
||||||
|
validation_data=(inputs_minus_labels, labels),
|
||||||
|
steps_per_epoch=1,
|
||||||
|
validation_steps=1,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
|
|
||||||
def test_generate_with_headmasking(self):
|
def test_generate_with_headmasking(self):
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user