XLA train step fixes (#17973)
* Copy inputs to train and test step before modifying them, as modifying them in place breaks things
* Add XLA tests, fix our loss functions to be XLA-compatible
* make fixup
* Update loss computation test to expect vector of per-sample losses
* Patch loss for TFLED
* Patch loss for TFAlbert
* Add a tf_legacy_loss config flag that enables old loss functions
* Stop using config.get() because the config is not a dict
* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it
* make fixup
* Add XLA-compatible RAG loss
* Fix dtype of loss mask for TFAlbert
* Fix test for XLNet too because it overrides the default one
* make fixup
* Fix config test
* No more depending on GPU NaN behaviour
* Add test, avoid potential zero division
* Fix test item assignment
* Fix loss computation masking test
* make fixup
* Fix dtype bugs
This commit is contained in:
@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
|
|
||||||
use_bfloat16 (`bool`, *optional*, defaults to `False`):
|
use_bfloat16 (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
|
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
|
||||||
|
tf_legacy_loss (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
|
||||||
|
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
|
||||||
|
v5.
|
||||||
"""
|
"""
|
||||||
model_type: str = ""
|
model_type: str = ""
|
||||||
is_composition: bool = False
|
is_composition: bool = False
|
||||||
@@ -260,6 +264,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||||
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
|
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
|
||||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||||
|
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
|
||||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||||
self.tie_word_embeddings = kwargs.pop(
|
self.tie_word_embeddings = kwargs.pop(
|
||||||
"tie_word_embeddings", True
|
"tie_word_embeddings", True
|
||||||
|
|||||||
@@ -195,11 +195,22 @@ class TFCausalLanguageModelingLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
|
if self.config.tf_legacy_loss:
|
||||||
|
# make sure only labels that are not equal to -100 affect the loss
|
||||||
|
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||||
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
return loss_fn(labels, reduced_logits)
|
||||||
|
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
|
||||||
# make sure only labels that are not equal to -100 affect the loss
|
# make sure only labels that are not equal to -100 affect the loss
|
||||||
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
# Avoid division by zero later
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
|
||||||
return loss_fn(labels, reduced_logits)
|
masked_loss = unmasked_loss * loss_mask
|
||||||
|
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
|
||||||
|
return reduced_masked_loss
|
||||||
|
|
||||||
|
|
||||||
class TFQuestionAnsweringLoss:
|
class TFQuestionAnsweringLoss:
|
||||||
@@ -232,17 +243,34 @@ class TFTokenClassificationLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA
|
||||||
# are taken into account as loss
|
if tf.math.reduce_any(labels == -1):
|
||||||
if tf.math.reduce_any(labels == -1):
|
tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||||
tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
|
||||||
else:
|
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -100
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
|
||||||
|
|
||||||
return loss_fn(labels, reduced_logits)
|
if self.config.tf_legacy_loss:
|
||||||
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account as loss
|
||||||
|
if tf.math.reduce_any(labels == -1):
|
||||||
|
tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||||
|
else:
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||||
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|
||||||
|
return loss_fn(labels, reduced_logits)
|
||||||
|
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
|
||||||
|
# make sure only labels that are not equal to -100 or -1
|
||||||
|
# are taken into account as loss
|
||||||
|
loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
|
||||||
|
# Avoid possible division by zero later
|
||||||
|
loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
|
||||||
|
# Zero out the loss at masked positions (negative labels such as -100 or -1): their labels were clipped to 0 above, so the unmasked values there are finite but meaningless
|
||||||
|
masked_loss = unmasked_loss * loss_mask
|
||||||
|
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
|
||||||
|
return reduced_masked_loss
|
||||||
|
|
||||||
|
|
||||||
class TFSequenceClassificationLoss:
|
class TFSequenceClassificationLoss:
|
||||||
@@ -251,7 +279,7 @@ class TFSequenceClassificationLoss:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def hf_compute_loss(self, labels, logits):
|
def hf_compute_loss(self, labels, logits):
|
||||||
if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
|
if logits.shape.rank == 1 or logits.shape[1] == 1:
|
||||||
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
|
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
|
||||||
else:
|
else:
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
@@ -298,13 +326,25 @@ class TFNextSentencePredictionLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
|
if self.config.tf_legacy_loss:
|
||||||
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account as loss
|
||||||
|
next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||||
|
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
|
||||||
|
next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
|
||||||
|
|
||||||
|
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
|
||||||
|
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100
|
||||||
# are taken into account as loss
|
# are taken into account as loss
|
||||||
next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
|
||||||
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
|
|
||||||
next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
|
|
||||||
|
|
||||||
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
|
||||||
|
ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
|
||||||
|
# Just zero out samples where label is -100, no reduction
|
||||||
|
masked_ns_loss = unmasked_ns_loss * ns_loss_mask
|
||||||
|
|
||||||
|
return masked_ns_loss
|
||||||
|
|
||||||
|
|
||||||
def booleans_processing(config, **kwargs):
|
def booleans_processing(config, **kwargs):
|
||||||
@@ -1327,6 +1367,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if not self._using_dummy_loss:
|
if not self._using_dummy_loss:
|
||||||
data = data_adapter.expand_1d(data)
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
|
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
||||||
|
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
||||||
|
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
||||||
|
if isinstance(x, dict):
|
||||||
|
x = x.copy()
|
||||||
|
if isinstance(y, dict):
|
||||||
|
y = y.copy()
|
||||||
|
|
||||||
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||||
# if those keys are not already present in the input dict
|
# if those keys are not already present in the input dict
|
||||||
@@ -1424,6 +1471,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
if not self._using_dummy_loss:
|
if not self._using_dummy_loss:
|
||||||
data = data_adapter.expand_1d(data)
|
data = data_adapter.expand_1d(data)
|
||||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||||
|
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
|
||||||
|
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
|
||||||
|
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
|
||||||
|
if isinstance(x, dict):
|
||||||
|
x = x.copy()
|
||||||
|
if isinstance(y, dict):
|
||||||
|
y = y.copy()
|
||||||
|
|
||||||
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
|
||||||
# if those keys are not already present in the input dict
|
# if those keys are not already present in the input dict
|
||||||
|
|||||||
@@ -86,29 +86,52 @@ class TFAlbertPreTrainingLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
if self.config.tf_legacy_loss:
|
||||||
# are taken into account as loss
|
# make sure only labels that are not equal to -100
|
||||||
masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
|
# are taken into account as loss
|
||||||
masked_lm_reduced_logits = tf.boolean_mask(
|
masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
|
||||||
tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
|
masked_lm_reduced_logits = tf.boolean_mask(
|
||||||
mask=masked_lm_active_loss,
|
tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
|
||||||
)
|
mask=masked_lm_active_loss,
|
||||||
masked_lm_labels = tf.boolean_mask(
|
)
|
||||||
tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
|
masked_lm_labels = tf.boolean_mask(
|
||||||
)
|
tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
|
||||||
sentence_order_active_loss = tf.not_equal(tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100)
|
)
|
||||||
sentence_order_reduced_logits = tf.boolean_mask(
|
sentence_order_active_loss = tf.not_equal(
|
||||||
tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
|
tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
|
||||||
)
|
)
|
||||||
sentence_order_label = tf.boolean_mask(
|
sentence_order_reduced_logits = tf.boolean_mask(
|
||||||
tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
|
tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
|
||||||
)
|
)
|
||||||
masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
|
sentence_order_label = tf.boolean_mask(
|
||||||
sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
|
tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
|
||||||
masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
|
)
|
||||||
masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
|
masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
|
||||||
|
sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
|
||||||
|
masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
|
||||||
|
masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
|
||||||
|
|
||||||
return masked_lm_loss + sentence_order_loss
|
return masked_lm_loss + sentence_order_loss
|
||||||
|
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
|
||||||
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account for the loss computation
|
||||||
|
lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
|
||||||
|
# Avoid division by zero later
|
||||||
|
lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
|
||||||
|
masked_lm_losses = unmasked_lm_losses * lm_loss_mask
|
||||||
|
reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
|
||||||
|
|
||||||
|
sop_logits = tf.reshape(logits[1], (-1, 2))
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
|
||||||
|
sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
|
||||||
|
|
||||||
|
# No reduction because this already has shape (num_samples,)
|
||||||
|
masked_sop_loss = unmasked_sop_loss * sop_loss_mask
|
||||||
|
|
||||||
|
return reduced_masked_lm_loss + masked_sop_loss
|
||||||
|
|
||||||
|
|
||||||
class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
||||||
|
|||||||
@@ -124,18 +124,22 @@ class TFBertPreTrainingLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
unmasked_lm_losses = loss_fn(y_true=labels["labels"], y_pred=logits[0])
|
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100
|
||||||
# are taken into account for the loss computation
|
# are taken into account for the loss computation
|
||||||
lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
|
lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
|
||||||
lm_loss_denominator = tf.reduce_sum(lm_loss_mask, axis=1)
|
# Avoid potential division by zero later
|
||||||
masked_lm_losses = tf.math.multiply_no_nan(unmasked_lm_losses, lm_loss_mask)
|
lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
|
||||||
|
masked_lm_losses = unmasked_lm_losses * lm_loss_mask
|
||||||
reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
|
reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
|
||||||
|
|
||||||
unmasked_ns_loss = loss_fn(y_true=labels["next_sentence_label"], y_pred=logits[1])
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
|
||||||
ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
|
ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
|
||||||
# Just zero out samples where label is -100, no reduction
|
# Just zero out samples where label is -100, no reduction
|
||||||
masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)
|
masked_ns_loss = unmasked_ns_loss * ns_loss_mask
|
||||||
|
|
||||||
return reduced_masked_lm_loss + masked_ns_loss
|
return reduced_masked_lm_loss + masked_ns_loss
|
||||||
|
|
||||||
|
|||||||
@@ -2505,11 +2505,20 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
|||||||
def hf_compute_loss(self, labels, logits):
|
def hf_compute_loss(self, labels, logits):
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
"""CrossEntropyLoss that ignores pad tokens"""
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True,
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
reduction=tf.keras.losses.Reduction.NONE,
|
|
||||||
)
|
)
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
if self.config.tf_legacy_loss:
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
melted_labels = tf.reshape(labels, (-1,))
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
return loss_fn(labels, reduced_logits)
|
labels = tf.boolean_mask(melted_labels, active_loss)
|
||||||
|
return loss_fn(labels, reduced_logits)
|
||||||
|
|
||||||
|
# Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
|
||||||
|
unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
|
||||||
|
# make sure only non-padding labels affect the loss
|
||||||
|
loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
|
||||||
|
loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
|
||||||
|
masked_loss = unmasked_loss * loss_mask
|
||||||
|
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
|
||||||
|
return reduced_masked_loss
|
||||||
|
|||||||
@@ -1333,27 +1333,46 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
|||||||
# Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
|
# Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
|
||||||
def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
|
def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
|
||||||
"""CrossEntropyLoss that ignores pad tokens"""
|
"""CrossEntropyLoss that ignores pad tokens"""
|
||||||
|
if self.config.tf_legacy_loss:
|
||||||
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
|
from_logits=True,
|
||||||
|
reduction=tf.keras.losses.Reduction.SUM,
|
||||||
|
)
|
||||||
|
|
||||||
|
if from_logits is False: # convert to logits
|
||||||
|
eps = 1e-9
|
||||||
|
y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
|
||||||
|
y_pred = tf.math.log(y_pred)
|
||||||
|
|
||||||
|
logits = y_pred
|
||||||
|
melted_labels = tf.reshape(labels, (-1,))
|
||||||
|
active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
|
||||||
|
|
||||||
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
|
||||||
|
labels = tf.boolean_mask(melted_labels, active_loss)
|
||||||
|
nll_loss = loss_fn(labels, reduced_logits)
|
||||||
|
|
||||||
|
smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
|
||||||
|
smooth_loss = tf.reduce_sum(smooth_loss) # sum and squeeze like torch
|
||||||
|
eps_i = smooth_epsilon / reduced_logits.shape[-1]
|
||||||
|
|
||||||
|
loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True,
|
from_logits=from_logits,
|
||||||
reduction=tf.keras.losses.Reduction.SUM,
|
reduction=tf.keras.losses.Reduction.NONE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if from_logits is False: # convert to logits
|
unmasked_loss = loss_fn(labels, y_pred)
|
||||||
eps = 1e-9
|
loss_mask = labels != self.config.generator.pad_token_id
|
||||||
y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
|
nll_loss = tf.reduce_sum(unmasked_loss * loss_mask)
|
||||||
y_pred = tf.math.log(y_pred)
|
|
||||||
|
|
||||||
logits = y_pred
|
# Matt: This makes no sense to me, but I'm just copying the old loss in XLA-compatible form
|
||||||
melted_labels = tf.reshape(labels, (-1,))
|
smooth_loss = -tf.reduce_sum(y_pred * tf.expand_dims(labels, -1), axis=-1)
|
||||||
active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)
|
smooth_loss = tf.reduce_sum(smooth_loss)
|
||||||
|
eps_i = smooth_epsilon / y_pred.shape[-1]
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(melted_labels, active_loss)
|
|
||||||
nll_loss = loss_fn(labels, reduced_logits)
|
|
||||||
|
|
||||||
smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
|
|
||||||
smooth_loss = tf.reduce_sum(smooth_loss) # sum and squeeze like torch
|
|
||||||
eps_i = smooth_epsilon / reduced_logits.shape[-1]
|
|
||||||
|
|
||||||
loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
|
loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
|
||||||
|
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
added_label = prepared_for_class[
|
added_label = prepared_for_class[
|
||||||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
||||||
]
|
]
|
||||||
loss_size = tf.size(added_label)
|
expected_loss_size = added_label.shape.as_list()[:1]
|
||||||
|
|
||||||
# `TFXLNetLMHeadModel` doesn't cut logits/labels
|
# `TFXLNetLMHeadModel` doesn't cut logits/labels
|
||||||
# if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
# if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
||||||
@@ -417,12 +417,12 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
input_ids = prepared_for_class.pop(input_name)
|
input_ids = prepared_for_class.pop(input_name)
|
||||||
|
|
||||||
loss = model(input_ids, **prepared_for_class)[0]
|
loss = model(input_ids, **prepared_for_class)[0]
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
# Test that model correctly compute the loss with a dict
|
# Test that model correctly compute the loss with a dict
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
loss = model(prepared_for_class)[0]
|
loss = model(prepared_for_class)[0]
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
# Test that model correctly compute the loss with a tuple
|
# Test that model correctly compute the loss with a tuple
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
@@ -453,7 +453,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# Send to model
|
# Send to model
|
||||||
loss = model(tuple_input[:-1])[0]
|
loss = model(tuple_input[:-1])[0]
|
||||||
|
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ config_common_kwargs = {
|
|||||||
"torchscript": True,
|
"torchscript": True,
|
||||||
"torch_dtype": "float16",
|
"torch_dtype": "float16",
|
||||||
"use_bfloat16": True,
|
"use_bfloat16": True,
|
||||||
|
"tf_legacy_loss": True,
|
||||||
"pruned_heads": {"a": 1},
|
"pruned_heads": {"a": 1},
|
||||||
"tie_word_embeddings": False,
|
"tie_word_embeddings": False,
|
||||||
"is_decoder": True,
|
"is_decoder": True,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from math import isnan
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -1284,12 +1285,7 @@ class TFModelTesterMixin:
|
|||||||
added_label = prepared_for_class[
|
added_label = prepared_for_class[
|
||||||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
|
||||||
]
|
]
|
||||||
loss_size = tf.size(added_label)
|
expected_loss_size = added_label.shape.as_list()[:1]
|
||||||
|
|
||||||
if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
|
|
||||||
# if loss is causal lm loss, labels are shift, so that one label per batch
|
|
||||||
# is cut
|
|
||||||
loss_size = loss_size - self.model_tester.batch_size
|
|
||||||
|
|
||||||
# Test that model correctly compute the loss with kwargs
|
# Test that model correctly compute the loss with kwargs
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
@@ -1298,12 +1294,26 @@ class TFModelTesterMixin:
|
|||||||
model_input = prepared_for_class.pop(input_name)
|
model_input = prepared_for_class.pop(input_name)
|
||||||
|
|
||||||
loss = model(model_input, **prepared_for_class)[0]
|
loss = model(model_input, **prepared_for_class)[0]
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
|
# Test that model correctly compute the loss when we mask some positions
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
possible_input_names = {"input_ids", "pixel_values", "input_features"}
|
||||||
|
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
|
||||||
|
model_input = prepared_for_class.pop(input_name)
|
||||||
|
if "labels" in prepared_for_class:
|
||||||
|
labels = prepared_for_class["labels"].numpy()
|
||||||
|
if len(labels.shape) > 1 and labels.shape[1] != 1:
|
||||||
|
labels[0] = -100
|
||||||
|
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
|
||||||
|
loss = model(model_input, **prepared_for_class)[0]
|
||||||
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
self.assertTrue(not np.any(np.isnan(loss.numpy())))
|
||||||
|
|
||||||
# Test that model correctly compute the loss with a dict
|
# Test that model correctly compute the loss with a dict
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
loss = model(prepared_for_class)[0]
|
loss = model(prepared_for_class)[0]
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
# Test that model correctly compute the loss with a tuple
|
# Test that model correctly compute the loss with a tuple
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
@@ -1334,7 +1344,7 @@ class TFModelTesterMixin:
|
|||||||
# Send to model
|
# Send to model
|
||||||
loss = model(tuple_input[:-1])[0]
|
loss = model(tuple_input[:-1])[0]
|
||||||
|
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape.as_list(), expected_loss_size)
|
||||||
|
|
||||||
def test_keras_fit(self):
|
def test_keras_fit(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -1397,6 +1407,7 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss1 = history1.history["val_loss"][0]
|
val_loss1 = history1.history["val_loss"][0]
|
||||||
|
self.assertTrue(not isnan(val_loss1))
|
||||||
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
|
||||||
|
|
||||||
# We reinitialize the model here even though our learning rate was zero
|
# We reinitialize the model here even though our learning rate was zero
|
||||||
@@ -1412,6 +1423,7 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss2 = history2.history["val_loss"][0]
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
|
self.assertTrue(not isnan(val_loss2))
|
||||||
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
|
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||||
@@ -1437,6 +1449,7 @@ class TFModelTesterMixin:
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
val_loss3 = history3.history["val_loss"][0]
|
val_loss3 = history3.history["val_loss"][0]
|
||||||
|
self.assertTrue(not isnan(val_loss3))
|
||||||
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
|
||||||
self.assertEqual(history1.history.keys(), history3.history.keys())
|
self.assertEqual(history1.history.keys(), history3.history.keys())
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from math import isnan
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
@@ -134,6 +135,72 @@ class TFCoreModelTesterMixin:
|
|||||||
outputs = run_in_graph_mode()
|
outputs = run_in_graph_mode()
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
|
@slow
def test_xla_fit(self):
    """
    Check that every model class exposing `hf_compute_loss` can be fit with XLA compilation enabled.

    This is a copy of the test_keras_fit method, but the model is compiled with `jit_compile=True` instead of
    running eagerly. Each model is fit for a single step twice: once with the labels left inside the input dict,
    and once (on a freshly built model) with the labels passed as a separate argument. In both cases the reported
    training and validation losses must not be NaN.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Inputs that are stripped from the prepared batch before fitting.
    dropped_input_keys = ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
    # Every key that is treated as a label column when splitting the batch.
    candidate_label_keys = {
        "labels",
        "label",
        "label_ids",
        "start_positions",
        "start_position",
        "end_positions",
        "end_position",
        "next_sentence_label",
    }
    # One training step and one validation step, silent and unshuffled.
    single_step_fit = dict(steps_per_epoch=1, validation_steps=1, shuffle=False, verbose=0)

    def _compile_fit_and_check(model, *fit_args, validation_data):
        # Make sure it works with XLA! The learning rate is zero, so only the loss values are of interest.
        model.compile(optimizer=tf.keras.optimizers.SGD(0.0), jit_compile=True)
        history = model.fit(*fit_args, validation_data=validation_data, **single_step_fit)
        for metric in ("loss", "val_loss"):
            self.assertTrue(not isnan(history.history[metric][0]))

    for model_class in self.all_model_classes:
        model = model_class(config)
        if not getattr(model, "hf_compute_loss", None):
            # Nothing to fit for models without a loss function.
            continue

        prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
        # Is there a better way to remove these decoder inputs?
        prepared_for_class = {
            name: tensor for name, tensor in prepared_for_class.items() if name not in dropped_input_keys
        }

        label_names = candidate_label_keys & set(prepared_for_class)
        self.assertGreater(len(label_names), 0, msg="No matching label names found!")

        # Split the batch into label columns and everything else, keeping the original key order.
        labels, inputs_minus_labels = {}, {}
        for name, tensor in prepared_for_class.items():
            destination = labels if name in label_names else inputs_minus_labels
            destination[name] = tensor
        self.assertGreater(len(inputs_minus_labels), 0)

        # Make sure the model fits without crashing regardless of where we pass the labels:
        # first with the labels embedded in the input dict...
        _compile_fit_and_check(model, prepared_for_class, validation_data=prepared_for_class)

        # ...then with separate labels on a fresh model, to make sure that path works in XLA too.
        model = model_class(config)
        _compile_fit_and_check(
            model,
            inputs_minus_labels,
            labels,
            validation_data=(inputs_minus_labels, labels),
        )
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user