* Compute loss independent from decoder (as 14139) * fix expected seq_len + style * Apply the same change to TFVisionEncoderDecoderModel * fix style * Add case with labels in equivalence test * uncomment * Add case with labels in equivalence test * add decoder_token_labels * use hf_compute_loss * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copied from Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -29,7 +30,13 @@ from ...file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
|
||||
from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
shape_list,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
|
||||
|
||||
DEPRECATION_WARNING = (
|
||||
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
|
||||
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
|
||||
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
|
||||
"need to pass them yourself anymore."
|
||||
)
|
||||
|
||||
ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
|
||||
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
|
||||
@@ -145,8 +159,36 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||
|
||||
if decoder_start_token_id is None:
|
||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
|
||||
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = tf.where(
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
|
||||
class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
r"""
|
||||
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
|
||||
of the base model classes of the library as encoder and another one as decoder when created with the
|
||||
@@ -566,6 +608,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
decoder_processing_inputs = {
|
||||
"func": self.decoder.call,
|
||||
"config": self.decoder.config,
|
||||
@@ -574,7 +621,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": attention_mask,
|
||||
"inputs_embeds": decoder_inputs_embeds,
|
||||
"labels": labels,
|
||||
"output_attentions": output_attentions,
|
||||
"output_hidden_states": output_hidden_states,
|
||||
"use_cache": use_cache,
|
||||
@@ -592,12 +638,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||
decoder_outputs = self.decoder(**decoder_inputs)
|
||||
|
||||
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
|
||||
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
|
||||
past_key_values = None
|
||||
logits = decoder_outputs[0]
|
||||
|
||||
# Compute loss independent from decoder (as some shift the logits inside them)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||
loss = self.hf_compute_loss(labels, logits)
|
||||
|
||||
past_key_values = None
|
||||
if decoder_inputs["use_cache"]:
|
||||
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
|
||||
past_key_values = decoder_outputs[1]
|
||||
# The starting index of the remaining elements in `decoder_outputs`
|
||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||
|
||||
@@ -611,7 +662,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
return output
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
loss=loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
@@ -693,6 +744,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -29,7 +30,13 @@ from ...file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
|
||||
from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
shape_list,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
|
||||
|
||||
DEPRECATION_WARNING = (
|
||||
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
|
||||
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
|
||||
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
|
||||
"need to pass them yourself anymore."
|
||||
)
|
||||
|
||||
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
|
||||
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
|
||||
@@ -134,8 +148,37 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
|
||||
if pad_token_id is None:
|
||||
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
|
||||
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||
|
||||
if decoder_start_token_id is None:
|
||||
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
|
||||
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
|
||||
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = tf.where(
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
|
||||
class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
r"""
|
||||
[`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
|
||||
with one of the base vision model classes of the library as encoder and another one of the base model classes as
|
||||
@@ -594,6 +637,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
|
||||
encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
|
||||
|
||||
@@ -605,7 +653,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"inputs_embeds": decoder_inputs_embeds,
|
||||
"labels": labels,
|
||||
"output_attentions": output_attentions,
|
||||
"output_hidden_states": output_hidden_states,
|
||||
"use_cache": use_cache,
|
||||
@@ -622,12 +669,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
decoder_inputs = input_processing(**decoder_processing_inputs)
|
||||
decoder_outputs = self.decoder(**decoder_inputs)
|
||||
|
||||
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
|
||||
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
|
||||
past_key_values = None
|
||||
logits = decoder_outputs[0]
|
||||
|
||||
# Compute loss independent from decoder (as some shift the logits inside them)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
||||
loss = self.hf_compute_loss(labels, logits)
|
||||
|
||||
past_key_values = None
|
||||
if decoder_inputs["use_cache"]:
|
||||
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
|
||||
past_key_values = decoder_outputs[1]
|
||||
# The starting index of the remaining elements in `decoder_outputs`
|
||||
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
|
||||
|
||||
@@ -641,7 +693,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
return output
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
loss=loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
@@ -715,6 +767,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -237,7 +238,7 @@ class TFEncoderDecoderMixin:
|
||||
)
|
||||
|
||||
# Make sure `loss` exist
|
||||
assert "loss" in outputs_encoder_decoder
|
||||
self.assertIn("loss", outputs_encoder_decoder)
|
||||
|
||||
batch_size, seq_len = decoder_input_ids.shape
|
||||
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||
@@ -319,12 +320,18 @@ class TFEncoderDecoderMixin:
|
||||
# prepare inputs
|
||||
tf_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||
if "labels" in pt_inputs:
|
||||
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
||||
tf_outputs = tf_model(**inputs_dict)
|
||||
if "loss" in tf_outputs:
|
||||
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||
tf_outputs = tf_outputs.to_tuple()
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -339,8 +346,12 @@ class TFEncoderDecoderMixin:
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||
if "loss" in tf_outputs_loaded:
|
||||
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -435,6 +446,8 @@ class TFEncoderDecoderMixin:
|
||||
def test_pt_tf_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||
|
||||
# Keep only common arguments
|
||||
arg_names = [
|
||||
"config",
|
||||
@@ -454,6 +467,9 @@ class TFEncoderDecoderMixin:
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||
inputs_dict_with_labels["labels"] = labels
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||
@@ -471,6 +487,10 @@ class TFEncoderDecoderMixin:
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check equivalence with labels
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# # check `enc_to_dec_proj` work as expected
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin:
|
||||
# prepare inputs
|
||||
tf_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||
if "labels" in pt_inputs:
|
||||
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
||||
tf_outputs = tf_model(**inputs_dict)
|
||||
if "loss" in tf_outputs:
|
||||
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||
tf_outputs = tf_outputs.to_tuple()
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin:
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||
if "loss" in tf_outputs_loaded:
|
||||
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin:
|
||||
def test_pt_tf_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||
|
||||
# Keep only common arguments
|
||||
arg_names = [
|
||||
"config",
|
||||
@@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin:
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||
inputs_dict_with_labels["labels"] = labels
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||
@@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin:
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check equivalence with labels
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# # check `enc_to_dec_proj` work as expected
|
||||
@@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user