* Compute loss independent from decoder (as 14139) * fix expected seq_len + style * Apply the same change to TFVisionEncoderDecoderModel * fix style * Add case with labels in equivalence test * uncomment * Add case with labels in equivalence test * add decoder_token_labels * use hf_compute_loss * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copied from Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -237,7 +238,7 @@ class TFEncoderDecoderMixin:
|
||||
)
|
||||
|
||||
# Make sure `loss` exist
|
||||
assert "loss" in outputs_encoder_decoder
|
||||
self.assertIn("loss", outputs_encoder_decoder)
|
||||
|
||||
batch_size, seq_len = decoder_input_ids.shape
|
||||
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
|
||||
@@ -319,12 +320,18 @@ class TFEncoderDecoderMixin:
|
||||
# prepare inputs
|
||||
tf_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||
if "labels" in pt_inputs:
|
||||
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
||||
tf_outputs = tf_model(**inputs_dict)
|
||||
if "loss" in tf_outputs:
|
||||
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
|
||||
tf_outputs = tf_outputs.to_tuple()
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -339,8 +346,12 @@ class TFEncoderDecoderMixin:
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
|
||||
if "loss" in tf_outputs_loaded:
|
||||
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
|
||||
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
|
||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
@@ -435,6 +446,8 @@ class TFEncoderDecoderMixin:
|
||||
def test_pt_tf_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
labels = config_inputs_dict.pop("decoder_token_labels")
|
||||
|
||||
# Keep only common arguments
|
||||
arg_names = [
|
||||
"config",
|
||||
@@ -454,6 +467,9 @@ class TFEncoderDecoderMixin:
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
inputs_dict_with_labels = copy.copy(inputs_dict)
|
||||
inputs_dict_with_labels["labels"] = labels
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||
@@ -471,6 +487,10 @@ class TFEncoderDecoderMixin:
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check equivalence with labels
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# # check `enc_to_dec_proj` work as expected
|
||||
|
||||
Reference in New Issue
Block a user