diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 6c58d60a30..d530884eee 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -18,6 +18,7 @@ import warnings from typing import Optional import torch +from torch import nn from torch.nn import CrossEntropyLoss from ...configuration_utils import PretrainedConfig @@ -25,6 +26,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo from ...modeling_outputs import Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM from .configuration_encoder_decoder import EncoderDecoderConfig @@ -181,13 +184,23 @@ class EncoderDecoderModel(PreTrainedModel): encoder: Optional[PreTrainedModel] = None, decoder: Optional[PreTrainedModel] = None, ): - assert config is not None or ( - encoder is not None and decoder is not None - ), "Either a configuration or an Encoder and a decoder has to be provided" + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") if config is None: config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) else: - assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}" + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " + f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " + f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." + ) + # initialize with config super().__init__(config) @@ -218,9 +231,17 @@ class EncoderDecoderModel(PreTrainedModel): self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder - assert ( - self.encoder.get_output_embeddings() is None - ), "The encoder {} should not have a LM Head. Please use a model without LM Head" + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) # tie encoder, decoder weights if config set accordingly self.tie_weights() @@ -251,8 +272,12 @@ class EncoderDecoderModel(PreTrainedModel): @classmethod def from_pretrained(cls, *args, **kwargs): - # At the moment fast initialization is not supported - # for composite models + # At the moment fast initialization is not supported for composite models + if kwargs.get("_fast_init", False): + logger.warning( + "Fast initialization is currently not supported for EncoderDecoderModel. " + "Falling back to slow initialization..." + ) kwargs["_fast_init"] = False return super().from_pretrained(*args, **kwargs) @@ -343,19 +368,18 @@ class EncoderDecoderModel(PreTrainedModel): # by the value of the flag `is_decoder` that we need to set correctly. encoder = kwargs_encoder.pop("model", None) if encoder is None: - assert ( - encoder_pretrained_model_name_or_path is not None - ), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined" - from ..auto.modeling_auto import AutoModel + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) if "config" not in kwargs_encoder: - from ..auto.configuration_auto import AutoConfig - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled." + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False @@ -366,18 +390,20 @@ class EncoderDecoderModel(PreTrainedModel): decoder = kwargs_decoder.pop("model", None) if decoder is None: - assert ( - decoder_pretrained_model_name_or_path is not None - ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" - from ..auto.modeling_auto import AutoModelForCausalLM + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) if "config" not in kwargs_decoder: - from ..auto.configuration_auto import AutoConfig - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -386,7 +412,11 @@ class EncoderDecoderModel(PreTrainedModel): if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) @@ -464,6 +494,13 @@ class EncoderDecoderModel(PreTrainedModel): encoder_hidden_states = encoder_outputs[0] + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index 22b87e5d62..5846610317 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput from ...modeling_flax_utils import FlaxPreTrainedModel from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM from .configuration_encoder_decoder import EncoderDecoderConfig @@ -227,9 +229,25 @@ class FlaxEncoderDecoderModule(nn.Module): self.encoder = encoder_module(encoder_config, dtype=self.dtype) self.decoder = decoder_module(decoder_config, dtype=self.dtype) + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Dense( + self.decoder.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), + dtype=self.dtype, + ) + else: + self.enc_to_dec_proj = None + def _get_encoder_module(self): return self.encoder + def _get_projection_module(self): + return self.enc_to_dec_proj + def _get_decoder_module(self): return self.decoder @@ -256,11 +274,17 @@ class FlaxEncoderDecoderModule(nn.Module): deterministic=deterministic, ) + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if self.enc_to_dec_proj is not None: + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -305,6 +329,15 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): if input_shape is None: input_shape = ((1, 1), (1, 1)) + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " + f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " + f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." + ) + module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) @@ -537,12 +570,22 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): else: mutable = False - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + def _decoder_forward( + module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs + ): + + projection_module = module._get_projection_module() decoder_module = module._get_decoder_module() + + # optionally project encoder_hidden_states + if projection_module is not None: + encoder_hidden_states = projection_module(encoder_hidden_states) + return decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, + encoder_hidden_states, **kwargs, ) @@ -772,19 +815,18 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): # by the value of the flag `is_decoder` that we need to set correctly. encoder = kwargs_encoder.pop("model", None) if encoder is None: - assert ( - encoder_pretrained_model_name_or_path is not None - ), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined" - from ..auto.modeling_flax_auto import FlaxAutoModel + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) if "config" not in kwargs_encoder: - from ..auto.configuration_auto import AutoConfig - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled." + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False @@ -797,18 +839,20 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): decoder = kwargs_decoder.pop("model", None) if decoder is None: - assert ( - decoder_pretrained_model_name_or_path is not None - ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" - from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) if "config" not in kwargs_decoder: - from ..auto.configuration_auto import AutoConfig - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -817,7 +861,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index cf682d1393..43905fcb6a 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -23,12 +23,13 @@ import tensorflow as tf from ...configuration_utils import PretrainedConfig from ...file_utils import ( DUMMY_INPUTS, + ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings, ) from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput -from ...modeling_tf_utils import TFPreTrainedModel, input_processing +from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing from ...utils import logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM @@ -168,12 +169,22 @@ class TFEncoderDecoderModel(TFPreTrainedModel): decoder: Optional[TFPreTrainedModel] = None, ): if config is None and (encoder is None or decoder is None): - raise ValueError("Either a configuration or an encoder and a decoder has to be provided") + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") if config is None: config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) else: if not isinstance(config, self.config_class): raise ValueError(f"config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " + f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " + f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." + ) + # initialize with config super().__init__(config) @@ -200,8 +211,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel): self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = tf.keras.layers.Dense( + units=self.decoder.config.hidden_size, + kernel_initializer=get_initializer(config.encoder.initializer_range), + name="enc_to_dec_proj", + ) + if self.encoder.get_output_embeddings() is not None: - raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head") + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) @property def dummy_inputs(self): @@ -355,16 +379,16 @@ class TFEncoderDecoderModel(TFPreTrainedModel): if encoder is None: if encoder_pretrained_model_name_or_path is None: raise ValueError( - "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined" + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled." + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False @@ -387,15 +411,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel): if decoder is None: if decoder_pretrained_model_name_or_path is None: raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -404,7 +431,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel): if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) kwargs_decoder["name"] = "decoder" @@ -485,6 +516,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel): argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + # Let the user be responsible for the expected format. + if encoder_outputs is not None: + if return_dict and not isinstance(encoder_outputs, ModelOutput): + raise ValueError( + "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of " + f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`." + ) + if encoder_outputs is None: encoder_processing_inputs = { @@ -518,6 +557,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel): encoder_hidden_states = encoder_outputs[0] + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + decoder_processing_inputs = { "func": self.decoder.call, "config": self.decoder.config, @@ -562,14 +608,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel): output = tuple([x for x in output if x is not None]) return output - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - if not isinstance(encoder_outputs, TFBaseModelOutput): - encoder_outputs = TFBaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - return TFSeq2SeqLMOutput( loss=decoder_outputs.loss, logits=decoder_outputs.logits, diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 98a41b3c55..75c939d906 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -195,6 +195,15 @@ class SpeechEncoderDecoderModel(PreTrainedModel): if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " + f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " + f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." + ) + # initialize with config # make sure input & output embeddings is not tied config.tie_word_embeddings = False @@ -225,7 +234,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel): # get encoder output hidden size self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size) - if self.encoder_output_dim != self.decoder.config.hidden_size: + if ( + self.encoder_output_dim != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): # encoder outputs might need to be projected to different dimension for decoder self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) @@ -248,11 +260,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel): @classmethod def from_pretrained(cls, *args, **kwargs): - # At the moment fast initialization is not supported - # for composite models + # At the moment fast initialization is not supported for composite models if kwargs.get("_fast_init", False): logger.warning( - "Fast initialization is currently not supported for SpeechEncoderDecoderModel. Falling back to slow intialization..." + "Fast initialization is currently not supported for SpeechEncoderDecoderModel. " + "Falling back to slow initialization..." ) kwargs["_fast_init"] = False return super().from_pretrained(*args, **kwargs) @@ -346,13 +358,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel): if encoder is None: if encoder_pretrained_model_name_or_path is None: raise ValueError( - f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. In this case make sure that `encoder_pretrained_model_name_or_path` defined" + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_encoder: encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " "from a decoder model. Cross-attention and casual mask are disabled." @@ -368,7 +380,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel): if decoder is None: if decoder_pretrained_model_name_or_path is None: raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_decoder: @@ -376,8 +389,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel): if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " - "Cross attention layers are added to {decoder_pretrained_model_name_or_path} " - "and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -389,7 +403,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel): f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`" + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) @@ -472,8 +487,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel): encoder_hidden_states = encoder_outputs[0] - # project encoder_hidden_states - if self.encoder_output_dim != self.decoder.config.hidden_size: + # optionally project encoder_hidden_states + if ( + self.encoder_output_dim != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) # compute correct encoder attention mask diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 09a23de1d6..85d1aa732b 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput from ...modeling_flax_utils import FlaxPreTrainedModel from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig @@ -301,8 +303,8 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): if config.decoder.cross_attention_hidden_size is not None: if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: raise ValueError( - f"If `cross_attention_hidden_size` is specified in the decoder's configuration, " - f"it has to be equal to the encoder's `hidden_size`." + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." ) @@ -781,19 +783,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): if encoder_pretrained_model_name_or_path is None: raise ValueError( "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " - "to be defined" + "to be defined." ) - from ..auto.modeling_flax_auto import FlaxAutoModel if "config" not in kwargs_encoder: - from ..auto.configuration_auto import AutoConfig - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( - f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder " - "model. Cross-attention and casual mask are disabled." + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." ) encoder_config.is_decoder = False encoder_config.add_cross_attention = False @@ -811,17 +809,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " "to be defined." ) - from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM if "config" not in kwargs_decoder: - from ..auto.configuration_auto import AutoConfig - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention " - f"layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if " - f"{decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -830,11 +826,11 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order " - f"to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the " - "attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to " - "`.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to " - "`.from_encoder_decoder_pretrained(...)`" + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 43e18a87a3..75c1bcd3b0 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -178,8 +178,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): if config.decoder.cross_attention_hidden_size is not None: if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: raise ValueError( - f"If `cross_attention_hidden_size` is specified in the decoder's configuration, " - f"it has to be equal to the encoder's `hidden_size`." + "If `cross_attention_hidden_size` is specified in the decoder's configuration, " + "it has to be equal to the encoder's `hidden_size`. " f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." ) @@ -241,7 +241,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): # At the moment fast initialization is not supported for composite models if kwargs.get("_fast_init", False): logger.warning( - "Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..." + "Fast initialization is currently not supported for VisionEncoderDecoderModel. " + "Falling back to slow initialization..." ) kwargs["_fast_init"] = False return super().from_pretrained(*args, **kwargs) @@ -334,14 +335,13 @@ class VisionEncoderDecoderModel(PreTrainedModel): if encoder is None: if encoder_pretrained_model_name_or_path is None: raise ValueError( - f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. " - f"In this case make sure that `encoder_pretrained_model_name_or_path` defined" + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_encoder: encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: - logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " "from a decoder model. Cross-attention and casual mask are disabled." @@ -357,16 +357,18 @@ class VisionEncoderDecoderModel(PreTrainedModel): if decoder is None: if decoder_pretrained_model_name_or_path is None: raise ValueError( - "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined" + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." ) if "config" not in kwargs_decoder: decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( - f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model." - "Cross attention layers are added to {decoder_pretrained_model_name_or_path} " - "and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " + f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " + f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " + "cross attention layers." ) decoder_config.is_decoder = True decoder_config.add_cross_attention = True @@ -375,11 +377,11 @@ class VisionEncoderDecoderModel(PreTrainedModel): if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: logger.warning( - f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder." + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " - "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config`" - "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` " - f"to `.from_encoder_decoder_pretrained(...)`" + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) diff --git a/tests/test_modeling_flax_encoder_decoder.py b/tests/test_modeling_flax_encoder_decoder.py index d6ffc9e896..267329f29d 100644 --- a/tests/test_modeling_flax_encoder_decoder.py +++ b/tests/test_modeling_flax_encoder_decoder.py @@ -19,8 +19,8 @@ import unittest import numpy as np -from transformers import is_flax_available -from transformers.testing_utils import require_flax, slow +from transformers import is_flax_available, is_torch_available +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device from .test_modeling_flax_bert import FlaxBertModelTester from .test_modeling_flax_common import ids_tensor @@ -35,6 +35,15 @@ if is_flax_available(): FlaxEncoderDecoderModel, FlaxGPT2LMHeadModel, ) + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + +if is_torch_available(): + import torch + + from transformers import EncoderDecoderModel @require_flax @@ -234,6 +243,71 @@ class FlaxEncoderDecoderMixin: generated_sequences = generated_output.sequences self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,)) + def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): + + pt_model.to(torch_device) + pt_model.eval() + + # prepare inputs + flax_inputs = inputs_dict + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_outputs = fx_model(**inputs_dict).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) + + # PT -> Flax + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() + self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) + + # Flax -> PT + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) + + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) + + def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): + + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + + pt_model = EncoderDecoderModel(encoder_decoder_config) + fx_model = FlaxEncoderDecoderModel(encoder_decoder_config) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) + + def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict): + + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + + pt_model = EncoderDecoderModel(encoder_decoder_config) + fx_model = FlaxEncoderDecoderModel(encoder_decoder_config) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict) + def test_encoder_decoder_model_from_pretrained_configs(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) @@ -258,6 +332,44 @@ class FlaxEncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict) + def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): + diff = np.abs((a - b)).max() + self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") + + @is_pt_flax_cross_test + def test_pt_flax_equivalence(self): + + config_inputs_dict = self.prepare_config_and_inputs() + config = config_inputs_dict.pop("config") + decoder_config = config_inputs_dict.pop("decoder_config") + + inputs_dict = config_inputs_dict + # `encoder_hidden_states` is not used in model call/forward + del inputs_dict["encoder_hidden_states"] + + # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) + batch_size = inputs_dict["decoder_attention_mask"].shape[0] + inputs_dict["decoder_attention_mask"] = np.concatenate( + [np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1 + ) + + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + decoder_config.use_cache = False + + self.assertTrue(decoder_config.cross_attention_hidden_size is None) + + # check without `enc_to_dec_proj` projection + self.assertTrue(config.hidden_size == decoder_config.hidden_size) + self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) + self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) + + # check `enc_to_dec_proj` work as expected + decoder_config.hidden_size = decoder_config.hidden_size * 2 + self.assertTrue(config.hidden_size != decoder_config.hidden_size) + self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict) + self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict) + @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model() diff --git a/tests/test_modeling_tf_encoder_decoder.py b/tests/test_modeling_tf_encoder_decoder.py index 7773711657..c13c5ede18 100644 --- a/tests/test_modeling_tf_encoder_decoder.py +++ b/tests/test_modeling_tf_encoder_decoder.py @@ -31,6 +31,8 @@ from .test_modeling_tf_roberta import TFRobertaModelTester if is_tf_available(): + import tensorflow as tf + from transformers import ( AutoConfig, AutoTokenizer, @@ -309,6 +311,90 @@ class TFEncoderDecoderMixin: ) self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,)) + def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict): + + pt_model.to(torch_device) + pt_model.eval() + + # prepare inputs + tf_inputs = inputs_dict + pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + tf_outputs = tf_model(**inputs_dict).to_tuple() + self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output, pt_output in zip(tf_outputs, pt_outputs): + self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3) + + # PT -> TF + with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: + + pt_model.encoder.save_pretrained(encoder_tmp_dirname) + pt_model.decoder.save_pretrained(decoder_tmp_dirname) + tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True + ) + # This is only for copying some specific attributes of this particular model. + tf_model_loaded.config = pt_model.config + + tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple() + self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch") + for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs): + self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3) + + def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict): + + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + + pt_model = EncoderDecoderModel(encoder_decoder_config) + + with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: + + pt_model.encoder.save_pretrained(encoder_tmp_dirname) + pt_model.decoder.save_pretrained(decoder_tmp_dirname) + tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True + ) + # This is only for copying some specific attributes of this particular model. + tf_model.config = pt_model.config + + self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict) + + def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict): + + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + + # Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving + # the encoder/decoder models. + # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see + # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245 + # (the change in `src/transformers/modeling_tf_utils.py`) + _tf_model = TFEncoderDecoderModel(encoder_decoder_config) + # Make sure model is built + _tf_model(**inputs_dict) + + # Using `tf_model` to pass the test. + encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder) + decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder) + # Make sure models are built + encoder(encoder.dummy_inputs) + decoder(decoder.dummy_inputs) + tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) + + with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: + + tf_model.encoder.save_pretrained(encoder_tmp_dirname) + tf_model.decoder.save_pretrained(decoder_tmp_dirname) + pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True + ) + # This is only for copying some specific attributes of this particular model. + pt_model.config = tf_model.config + + self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict) + def test_encoder_decoder_model(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model(**input_ids_dict) @@ -341,6 +427,65 @@ class TFEncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict) + def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): + diff = np.abs((a - b)).max() + self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).") + + @is_pt_tf_cross_test + def test_pt_tf_equivalence(self): + + config_inputs_dict = self.prepare_config_and_inputs() + # Keep only common arguments + arg_names = [ + "config", + "input_ids", + "attention_mask", + "decoder_config", + "decoder_input_ids", + "decoder_attention_mask", + "encoder_hidden_states", + ] + config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names} + + config = config_inputs_dict.pop("config") + decoder_config = config_inputs_dict.pop("decoder_config") + + inputs_dict = config_inputs_dict + # `encoder_hidden_states` is not used in model call/forward + del inputs_dict["encoder_hidden_states"] + + # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask) + batch_size = inputs_dict["decoder_attention_mask"].shape[0] + inputs_dict["decoder_attention_mask"] = tf.constant( + np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1) + ) + + # TF models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + decoder_config.use_cache = False + + self.assertTrue(decoder_config.cross_attention_hidden_size is None) + + # check without `enc_to_dec_proj` projection + self.assertTrue(config.hidden_size == decoder_config.hidden_size) + self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict) + self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict) + + # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`, + # which randomly initialize `enc_to_dec_proj`. + # # check `enc_to_dec_proj` work as expected + # decoder_config.hidden_size = decoder_config.hidden_size * 2 + # self.assertTrue(config.hidden_size != decoder_config.hidden_size) + # self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict) + # self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict) + + # Let's just check `enc_to_dec_proj` can run for now + decoder_config.hidden_size = decoder_config.hidden_size * 2 + self.assertTrue(config.hidden_size != decoder_config.hidden_size) + encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) + model = TFEncoderDecoderModel(encoder_decoder_config) + model(**inputs_dict) + @slow def test_real_model_save_load_from_pretrained(self): model_2 = self.get_pretrained_model()