Use cross_attention_hidden_size in Encoder-Decoder models (#14378)
* add cross_attention_hidden_size to text-2-text encoder-decoder models (PT/Flax)
* for TFEncoderDecoderModel
* add equivalence test for TFEncoderDecoderModel
* fix
* fix failed equivalence tests
* remove unused import
* add detailed comment
* Fix check_equivalence_tf_to_pt by using encoder/decoder
* cleaning
* Use cross_attention_hidden_size in speech-to-text
* clean fast init logging msg in encoder decoder models
* increase tol from 1e-5 to 1e-3 for tf test
* style
* style
* make sure projection layer can run
* remove type conversion + add check
* fix conflict (config.output_hidden_size)
* Remove TF -> PT in check_pt_tf_equivalence for TFEncoderDecoderModel

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -31,6 +31,8 @@ from .test_modeling_tf_roberta import TFRobertaModelTester
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@@ -309,6 +311,90 @@ class TFEncoderDecoderMixin:
|
||||
)
|
||||
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
    """Assert that a PyTorch and a TensorFlow encoder-decoder model agree on the same inputs.

    Two comparisons are made, each with an absolute tolerance of 1e-3:

    1. The outputs of `tf_model` against the outputs of `pt_model`.
    2. The outputs of a TF model re-loaded from the saved PyTorch encoder and
       decoder (PT -> TF) against the outputs of `pt_model`.

    Args:
        pt_model: PyTorch model exposing `encoder`, `decoder` and `config`.
        tf_model: TensorFlow counterpart of `pt_model`.
        inputs_dict: Keyword arguments for the model call. Each value must
            expose `.numpy()` so it can be converted to a PyTorch tensor.
    """
    pt_model.to(torch_device)
    pt_model.eval()

    # prepare inputs
    tf_inputs = inputs_dict
    # The model was moved to `torch_device`, so the inputs have to live there too;
    # otherwise the forward pass fails with a device mismatch on non-CPU devices.
    pt_inputs = {k: torch.tensor(v.numpy()).to(torch_device) for k, v in tf_inputs.items()}

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs).to_tuple()

    tf_outputs = tf_model(**inputs_dict).to_tuple()
    self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
    for tf_output, pt_output in zip(tf_outputs, pt_outputs):
        # `.cpu()` is required before `.numpy()` when the tensor is not on the CPU.
        self.assert_almost_equals(tf_output.numpy(), pt_output.cpu().numpy(), 1e-3)

    # PT -> TF
    with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
        pt_model.encoder.save_pretrained(encoder_tmp_dirname)
        pt_model.decoder.save_pretrained(decoder_tmp_dirname)
        tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
        )
        # This is only for copying some specific attributes of this particular model.
        tf_model_loaded.config = pt_model.config

    tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
    self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
    for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
        self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.cpu().numpy(), 1e-3)
|
||||
|
||||
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
    """Build a PyTorch encoder-decoder model, port it to TensorFlow, and compare the two.

    The PyTorch encoder and decoder are saved to separate temporary
    directories and re-loaded as a `TFEncoderDecoderModel` with
    `encoder_from_pt=True` / `decoder_from_pt=True`. The pair is then handed
    to `check_pt_tf_equivalence`.

    Args:
        config: Encoder configuration.
        decoder_config: Decoder configuration.
        inputs_dict: Keyword arguments forwarded to the equivalence check.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    pt_model = EncoderDecoderModel(combined_config)

    with tempfile.TemporaryDirectory() as encoder_dir:
        with tempfile.TemporaryDirectory() as decoder_dir:
            pt_model.encoder.save_pretrained(encoder_dir)
            pt_model.decoder.save_pretrained(decoder_dir)
            tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_dir,
                decoder_dir,
                encoder_from_pt=True,
                decoder_from_pt=True,
            )
            # Only needed to copy some model-specific attributes over to the TF model.
            tf_model.config = pt_model.config

    self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||
|
||||
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
    """Build a TensorFlow encoder-decoder model, port it to PyTorch, and compare the two.

    The TF model under test is assembled from a separately instantiated
    encoder and decoder (see the inline comment for why). Both parts are saved
    to temporary directories and re-loaded as an `EncoderDecoderModel` with
    `encoder_from_tf=True` / `decoder_from_tf=True`. The pair is then handed
    to `check_pt_tf_equivalence`.

    Args:
        config: Encoder configuration.
        decoder_config: Decoder configuration.
        inputs_dict: Keyword arguments used to build the composite TF model
            and forwarded to the equivalence check.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

    # Using the composite model directly makes the test fail, because its weights get extended
    # before the encoder/decoder models are saved.
    # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
    #   https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
    #   (the change in `src/transformers/modeling_tf_utils.py`)
    composite_model = TFEncoderDecoderModel(combined_config)
    # Run one call so that the composite model is built.
    composite_model(**inputs_dict)

    # Instantiate fresh encoder/decoder models of the same classes so the test can pass.
    encoder_cls = composite_model.encoder.__class__
    decoder_cls = composite_model.decoder.__class__
    encoder = encoder_cls(combined_config.encoder)
    decoder = decoder_cls(combined_config.decoder)
    # Run one call on each so that both models are built.
    encoder(encoder.dummy_inputs)
    decoder(decoder.dummy_inputs)
    tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)

    with tempfile.TemporaryDirectory() as encoder_dir:
        with tempfile.TemporaryDirectory() as decoder_dir:
            tf_model.encoder.save_pretrained(encoder_dir)
            tf_model.decoder.save_pretrained(decoder_dir)
            pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder_dir,
                decoder_dir,
                encoder_from_tf=True,
                decoder_from_tf=True,
            )
            # Only needed to copy some model-specific attributes over to the PT model.
            pt_model.config = tf_model.config

    self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model(**input_ids_dict)
|
||||
@@ -341,6 +427,65 @@ class TFEncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
    """Assert that the largest element-wise absolute difference between `a` and `b` is at most `tol`.

    Args:
        a: First array.
        b: Second array; `a - b` must be a valid NumPy subtraction.
        tol: Maximum allowed absolute difference (inclusive).

    Raises:
        AssertionError: If the maximum absolute difference is greater than `tol`.
    """
    diff = np.abs(a - b).max()
    # The check is `diff <= tol`, so a failure means `diff > tol` (not `>=`).
    self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (> {tol}).")
|
||||
|
||||
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
    """Check PT <-> TF output equivalence of the encoder-decoder model.

    The test first compares PyTorch and TensorFlow models whose encoder and
    decoder share the same hidden size (no `enc_to_dec_proj` projection), in
    both conversion directions. It then only checks that a TF model with
    differing hidden sizes can run a forward pass.
    """
    config_inputs_dict = self.prepare_config_and_inputs()
    # Keep only common arguments
    arg_names = [
        "config",
        "input_ids",
        "attention_mask",
        "decoder_config",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_hidden_states",
    ]
    config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names}

    config = config_inputs_dict.pop("config")
    decoder_config = config_inputs_dict.pop("decoder_config")

    inputs_dict = config_inputs_dict
    # `encoder_hidden_states` is not used in model call/forward. The filter above
    # tolerates a missing key, so removing it must not raise either.
    inputs_dict.pop("encoder_hidden_states", None)

    # Avoid the case where a sequence has no place to attend (after being combined with the causal attention mask)
    batch_size = inputs_dict["decoder_attention_mask"].shape[0]
    inputs_dict["decoder_attention_mask"] = tf.constant(
        np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
    )

    # TF models don't use the `use_cache` option and cache is not returned as a default.
    # So we disable `use_cache` here for PyTorch model.
    decoder_config.use_cache = False

    self.assertIsNone(decoder_config.cross_attention_hidden_size)

    # check without `enc_to_dec_proj` projection
    self.assertEqual(config.hidden_size, decoder_config.hidden_size)
    self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
    self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)

    # TODO: run both equivalence checks with `enc_to_dec_proj` as well (i.e. with different encoder/decoder
    # hidden sizes). This does not work yet, because the pt/tf equivalence tests for encoder-decoder use
    # `from_encoder_decoder_pretrained`, which randomly initializes `enc_to_dec_proj`.

    # Let's just check `enc_to_dec_proj` can run for now
    decoder_config.hidden_size = decoder_config.hidden_size * 2
    self.assertNotEqual(config.hidden_size, decoder_config.hidden_size)
    encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    model = TFEncoderDecoderModel(encoder_decoder_config)
    model(**inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
|
||||
Reference in New Issue
Block a user