* Compute loss independent from decoder (as 14139) * fix expected seq_len + style * Apply the same change to TFVisionEncoderDecoderModel * fix style * Add case with labels in equivalence test * uncomment * Add case with labels in equivalence test * add decoder_token_labels * use hf_compute_loss * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copied from Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -29,7 +30,13 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
|
from ...modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
get_initializer,
|
||||||
|
input_processing,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||||
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
||||||
|
|
||||||
|
DEPRECATION_WARNING = (
|
||||||
|
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
|
||||||
|
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
|
||||||
|
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
|
||||||
|
"need to pass them yourself anymore."
|
||||||
|
)
|
||||||
|
|
||||||
ENCODER_DECODER_START_DOCSTRING = r"""
|
ENCODER_DECODER_START_DOCSTRING = r"""
|
||||||
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
|
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
|
||||||
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
|
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
|
||||||
@@ -145,8 +159,36 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
|
|
||||||
|
if pad_token_id is None:
|
||||||
|
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||||
|
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||||
|
|
||||||
|
if decoder_start_token_id is None:
|
||||||
|
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||||
|
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
|
||||||
|
|
||||||
|
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||||
|
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||||
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
|
shifted_input_ids = tf.where(
|
||||||
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if tf.executing_eagerly():
|
||||||
|
# "Verify that `labels` has only positive values and -100"
|
||||||
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
|
with tf.control_dependencies([assert_gte0]):
|
||||||
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
||||||
class TFEncoderDecoderModel(TFPreTrainedModel):
|
class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
r"""
|
r"""
|
||||||
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
|
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
|
||||||
of the base model classes of the library as encoder and another one as decoder when created with the
|
of the base model classes of the library as encoder and another one as decoder when created with the
|
||||||
@@ -566,6 +608,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
):
|
):
|
||||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||||
|
decoder_input_ids = shift_tokens_right(
|
||||||
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
|
||||||
decoder_processing_inputs = {
|
decoder_processing_inputs = {
|
||||||
"func": self.decoder.call,
|
"func": self.decoder.call,
|
||||||
"config": self.decoder.config,
|
"config": self.decoder.config,
|
||||||
@@ -574,7 +621,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
"encoder_hidden_states": encoder_hidden_states,
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
"encoder_attention_mask": attention_mask,
|
"encoder_attention_mask": attention_mask,
|
||||||
"inputs_embeds": decoder_inputs_embeds,
|
"inputs_embeds": decoder_inputs_embeds,
|
||||||
"labels": labels,
|
|
||||||
"output_attentions": output_attentions,
|
"output_attentions": output_attentions,
|
||||||
"output_hidden_states": output_hidden_states,
|
"output_hidden_states": output_hidden_states,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
@@ -592,12 +638,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||||
decoder_outputs = self.decoder(**decoder_inputs)
|
decoder_outputs = self.decoder(**decoder_inputs)
|
||||||
|
|
||||||
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
|
logits = decoder_outputs[0]
|
||||||
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
|
# Compute loss independent from decoder (as some shift the logits inside them)
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||||
|
loss = self.hf_compute_loss(labels, logits)
|
||||||
|
|
||||||
|
past_key_values = None
|
||||||
if decoder_inputs["use_cache"]:
|
if decoder_inputs["use_cache"]:
|
||||||
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
|
past_key_values = decoder_outputs[1]
|
||||||
# The starting index of the remaining elements in `decoder_outputs`
|
# The starting index of the remaining elements in `decoder_outputs`
|
||||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||||
|
|
||||||
@@ -611,7 +662,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=decoder_outputs.loss,
|
loss=loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
@@ -693,6 +744,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def resize_token_embeddings(self, *args, **kwargs):
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
|
"Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -29,7 +30,13 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
|
from ...modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
get_initializer,
|
||||||
|
input_processing,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||||
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
|
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
|
||||||
|
|
||||||
|
DEPRECATION_WARNING = (
|
||||||
|
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
|
||||||
|
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
|
||||||
|
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
|
||||||
|
"need to pass them yourself anymore."
|
||||||
|
)
|
||||||
|
|
||||||
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
|
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
|
||||||
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
|
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
|
||||||
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
|
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
|
||||||
@@ -134,8 +148,37 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right
|
||||||
|
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
|
|
||||||
|
if pad_token_id is None:
|
||||||
|
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||||
|
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||||
|
|
||||||
|
if decoder_start_token_id is None:
|
||||||
|
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||||
|
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
|
||||||
|
|
||||||
|
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||||
|
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||||
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
|
shifted_input_ids = tf.where(
|
||||||
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if tf.executing_eagerly():
|
||||||
|
# "Verify that `labels` has only positive values and -100"
|
||||||
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
|
with tf.control_dependencies([assert_gte0]):
|
||||||
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
|
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
|
||||||
class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
r"""
|
r"""
|
||||||
[`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
[`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
||||||
with one of the base vision model classes of the library as encoder and another one of the base model classes as
|
with one of the base vision model classes of the library as encoder and another one of the base model classes as
|
||||||
@@ -594,6 +637,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
):
|
):
|
||||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||||
|
decoder_input_ids = shift_tokens_right(
|
||||||
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
)
|
||||||
|
|
||||||
batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
|
batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
|
||||||
encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
|
encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
|
||||||
|
|
||||||
@@ -605,7 +653,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
"encoder_hidden_states": encoder_hidden_states,
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
"encoder_attention_mask": encoder_attention_mask,
|
"encoder_attention_mask": encoder_attention_mask,
|
||||||
"inputs_embeds": decoder_inputs_embeds,
|
"inputs_embeds": decoder_inputs_embeds,
|
||||||
"labels": labels,
|
|
||||||
"output_attentions": output_attentions,
|
"output_attentions": output_attentions,
|
||||||
"output_hidden_states": output_hidden_states,
|
"output_hidden_states": output_hidden_states,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
@@ -622,12 +669,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||||
decoder_outputs = self.decoder(**decoder_inputs)
|
decoder_outputs = self.decoder(**decoder_inputs)
|
||||||
|
|
||||||
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
|
logits = decoder_outputs[0]
|
||||||
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
|
# Compute loss independent from decoder (as some shift the logits inside them)
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||||
|
loss = self.hf_compute_loss(labels, logits)
|
||||||
|
|
||||||
|
past_key_values = None
|
||||||
if decoder_inputs["use_cache"]:
|
if decoder_inputs["use_cache"]:
|
||||||
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
|
past_key_values = decoder_outputs[1]
|
||||||
# The starting index of the remaining elements in `decoder_outputs`
|
# The starting index of the remaining elements in `decoder_outputs`
|
||||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||||
|
|
||||||
@@ -641,7 +693,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=decoder_outputs.loss,
|
loss=loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
@@ -715,6 +767,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||||
|
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||||
|
|
||||||
def resize_token_embeddings(self, *args, **kwargs):
|
def resize_token_embeddings(self, *args, **kwargs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
|
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -237,7 +238,7 @@ class TFEncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Make sure `loss` exist
|
# Make sure `loss` exist
|
||||||
assert "loss" in outputs_encoder_decoder
|
self.assertIn("loss", outputs_encoder_decoder)
|
||||||
|
|
||||||
batch_size, seq_len = decoder_input_ids.shape
|
batch_size, seq_len = decoder_input_ids.shape
|
||||||
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||||
@@ -319,12 +320,18 @@ class TFEncoderDecoderMixin:
|
|||||||
# prepare inputs
|
# prepare inputs
|
||||||
tf_inputs = inputs_dict
|
tf_inputs = inputs_dict
|
||||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||||
|
if "labels" in pt_inputs:
|
||||||
|
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
tf_outputs = tf_model(**inputs_dict)
|
||||||
|
if "loss" in tf_outputs:
|
||||||
|
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||||
|
tf_outputs = tf_outputs.to_tuple()
|
||||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
@@ -339,8 +346,12 @@ class TFEncoderDecoderMixin:
|
|||||||
# This is only for copying some specific attributes of this particular model.
|
# This is only for copying some specific attributes of this particular model.
|
||||||
tf_model_loaded.config = pt_model.config
|
tf_model_loaded.config = pt_model.config
|
||||||
|
|
||||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||||
|
if "loss" in tf_outputs_loaded:
|
||||||
|
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||||
|
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
@@ -435,6 +446,8 @@ class TFEncoderDecoderMixin:
|
|||||||
def test_pt_tf_equivalence(self):
|
def test_pt_tf_equivalence(self):
|
||||||
|
|
||||||
config_inputs_dict = self.prepare_config_and_inputs()
|
config_inputs_dict = self.prepare_config_and_inputs()
|
||||||
|
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||||
|
|
||||||
# Keep only common arguments
|
# Keep only common arguments
|
||||||
arg_names = [
|
arg_names = [
|
||||||
"config",
|
"config",
|
||||||
@@ -454,6 +467,9 @@ class TFEncoderDecoderMixin:
|
|||||||
# `encoder_hidden_states` is not used in model call/forward
|
# `encoder_hidden_states` is not used in model call/forward
|
||||||
del inputs_dict["encoder_hidden_states"]
|
del inputs_dict["encoder_hidden_states"]
|
||||||
|
|
||||||
|
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||||
|
inputs_dict_with_labels["labels"] = labels
|
||||||
|
|
||||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||||
@@ -471,6 +487,10 @@ class TFEncoderDecoderMixin:
|
|||||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||||
|
|
||||||
|
# check equivalence with labels
|
||||||
|
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||||
|
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||||
|
|
||||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||||
# which randomly initialize `enc_to_dec_proj`.
|
# which randomly initialize `enc_to_dec_proj`.
|
||||||
# # check `enc_to_dec_proj` work as expected
|
# # check `enc_to_dec_proj` work as expected
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
|
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
# prepare inputs
|
# prepare inputs
|
||||||
tf_inputs = inputs_dict
|
tf_inputs = inputs_dict
|
||||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||||
|
if "labels" in pt_inputs:
|
||||||
|
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
tf_outputs = tf_model(**inputs_dict)
|
||||||
|
if "loss" in tf_outputs:
|
||||||
|
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||||
|
tf_outputs = tf_outputs.to_tuple()
|
||||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
@@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
# This is only for copying some specific attributes of this particular model.
|
# This is only for copying some specific attributes of this particular model.
|
||||||
tf_model_loaded.config = pt_model.config
|
tf_model_loaded.config = pt_model.config
|
||||||
|
|
||||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||||
|
if "loss" in tf_outputs_loaded:
|
||||||
|
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||||
|
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
@@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
def test_pt_tf_equivalence(self):
|
def test_pt_tf_equivalence(self):
|
||||||
|
|
||||||
config_inputs_dict = self.prepare_config_and_inputs()
|
config_inputs_dict = self.prepare_config_and_inputs()
|
||||||
|
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||||
|
|
||||||
# Keep only common arguments
|
# Keep only common arguments
|
||||||
arg_names = [
|
arg_names = [
|
||||||
"config",
|
"config",
|
||||||
@@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
# `encoder_hidden_states` is not used in model call/forward
|
# `encoder_hidden_states` is not used in model call/forward
|
||||||
del inputs_dict["encoder_hidden_states"]
|
del inputs_dict["encoder_hidden_states"]
|
||||||
|
|
||||||
|
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||||
|
inputs_dict_with_labels["labels"] = labels
|
||||||
|
|
||||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||||
@@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin:
|
|||||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||||
|
|
||||||
|
# check equivalence with labels
|
||||||
|
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||||
|
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||||
|
|
||||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||||
# which randomly initialize `enc_to_dec_proj`.
|
# which randomly initialize `enc_to_dec_proj`.
|
||||||
# # check `enc_to_dec_proj` work as expected
|
# # check `enc_to_dec_proj` work as expected
|
||||||
@@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te
|
|||||||
"decoder_config": decoder_config,
|
"decoder_config": decoder_config,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"decoder_token_labels": decoder_token_labels,
|
||||||
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
|
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
|
||||||
"labels": decoder_token_labels,
|
"labels": decoder_token_labels,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user