* Compute loss independent from decoder (as 14139) * fix expected seq_len + style * Apply the same change to TFVisionEncoderDecoderModel * fix style * Add case with labels in equivalence test * uncomment * Add case with labels in equivalence test * add decoder_token_labels * use hf_compute_loss * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copied from Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin:
|
||||
# prepare inputs
|
||||
tf_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||
if "labels" in pt_inputs:
|
||||
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
||||
tf_outputs = tf_model(**inputs_dict)
|
||||
if "loss" in tf_outputs:
|
||||
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||
tf_outputs = tf_outputs.to_tuple()
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin:
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||
if "loss" in tf_outputs_loaded:
|
||||
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin:
|
||||
def test_pt_tf_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||
|
||||
# Keep only common arguments
|
||||
arg_names = [
|
||||
"config",
|
||||
@@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin:
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||
inputs_dict_with_labels["labels"] = labels
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||
@@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin:
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check equivalence with labels
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# # check `enc_to_dec_proj` work as expected
|
||||
@@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user