Use cross_attention_hidden_size in Encoder-Decoder models (#14378)
* add cross_attention_hidden_size to text-2-text encoder-decoder models (PT/Flax) * for TFEncoderDecoderModel * add equivalence test for TFEncoderDecoderModel * fix * fix failed equivalence tests * remove unused import * add detailed comment * Fix check_equivalence_tf_to_pt by using encoder/decoder * cleaning * Use cross_attention_hidden_size in speech-to-text * clean fast init logging msg in encoder decoder models * increase tol from 1e-5 to 1e-3 for tf test * style * style * make sure projection layer can run * remove type conversion + add check * fix conflict (config.output_hidden_size) * Remove TF -> PT in check_pt_tf_equivalence for TFEncoderDecoderModel Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
@@ -25,6 +26,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
||||
from ...modeling_outputs import Seq2SeqLMOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
|
||||
|
||||
@@ -181,13 +184,23 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
encoder: Optional[PreTrainedModel] = None,
|
||||
decoder: Optional[PreTrainedModel] = None,
|
||||
):
|
||||
assert config is not None or (
|
||||
encoder is not None and decoder is not None
|
||||
), "Either a configuration or an Encoder and a decoder has to be provided"
|
||||
if config is None and (encoder is None or decoder is None):
|
||||
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
|
||||
if config is None:
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||
else:
|
||||
assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
|
||||
if not isinstance(config, self.config_class):
|
||||
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
||||
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
|
||||
# initialize with config
|
||||
super().__init__(config)
|
||||
|
||||
@@ -218,9 +231,17 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
self.encoder.config = self.config.encoder
|
||||
self.decoder.config = self.config.decoder
|
||||
|
||||
assert (
|
||||
self.encoder.get_output_embeddings() is None
|
||||
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||
|
||||
if self.encoder.get_output_embeddings() is not None:
|
||||
raise ValueError(
|
||||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
# tie encoder, decoder weights if config set accordingly
|
||||
self.tie_weights()
|
||||
@@ -251,8 +272,12 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# At the moment fast initialization is not supported
|
||||
# for composite models
|
||||
# At the moment fast initialization is not supported for composite models
|
||||
if kwargs.get("_fast_init", False):
|
||||
logger.warning(
|
||||
"Fast initialization is currently not supported for EncoderDecoderModel. "
|
||||
"Falling back to slow initialization..."
|
||||
)
|
||||
kwargs["_fast_init"] = False
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
@@ -343,19 +368,18 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
assert (
|
||||
encoder_pretrained_model_name_or_path is not None
|
||||
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
@@ -366,18 +390,20 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
assert (
|
||||
decoder_pretrained_model_name_or_path is not None
|
||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
from ..auto.modeling_auto import AutoModelForCausalLM
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -386,7 +412,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
@@ -464,6 +494,13 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||
decoder_input_ids = shift_tokens_right(
|
||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||
|
||||
@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
|
||||
|
||||
@@ -227,9 +229,25 @@ class FlaxEncoderDecoderModule(nn.Module):
|
||||
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
||||
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
||||
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
self.enc_to_dec_proj = nn.Dense(
|
||||
self.decoder.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
self.enc_to_dec_proj = None
|
||||
|
||||
def _get_encoder_module(self):
|
||||
return self.encoder
|
||||
|
||||
def _get_projection_module(self):
|
||||
return self.enc_to_dec_proj
|
||||
|
||||
def _get_decoder_module(self):
|
||||
return self.decoder
|
||||
|
||||
@@ -256,11 +274,17 @@ class FlaxEncoderDecoderModule(nn.Module):
|
||||
deterministic=deterministic,
|
||||
)
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if self.enc_to_dec_proj is not None:
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
position_ids=decoder_position_ids,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -305,6 +329,15 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, 1))
|
||||
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@@ -537,12 +570,22 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
|
||||
def _decoder_forward(
|
||||
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
|
||||
):
|
||||
|
||||
projection_module = module._get_projection_module()
|
||||
decoder_module = module._get_decoder_module()
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if projection_module is not None:
|
||||
encoder_hidden_states = projection_module(encoder_hidden_states)
|
||||
|
||||
return decoder_module(
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
decoder_position_ids,
|
||||
encoder_hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -772,19 +815,18 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
assert (
|
||||
encoder_pretrained_model_name_or_path is not None
|
||||
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModel
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
@@ -797,18 +839,20 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
assert (
|
||||
decoder_pretrained_model_name_or_path is not None
|
||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -817,7 +861,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
@@ -23,12 +23,13 @@ import tensorflow as tf
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, input_processing
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||
@@ -168,12 +169,22 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
decoder: Optional[TFPreTrainedModel] = None,
|
||||
):
|
||||
if config is None and (encoder is None or decoder is None):
|
||||
raise ValueError("Either a configuration or an encoder and a decoder has to be provided")
|
||||
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
|
||||
if config is None:
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||
else:
|
||||
if not isinstance(config, self.config_class):
|
||||
raise ValueError(f"config: {config} has to be of type {self.config_class}")
|
||||
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
|
||||
# initialize with config
|
||||
super().__init__(config)
|
||||
|
||||
@@ -200,8 +211,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
self.encoder.config = self.config.encoder
|
||||
self.decoder.config = self.config.decoder
|
||||
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
self.enc_to_dec_proj = tf.keras.layers.Dense(
|
||||
units=self.decoder.config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.encoder.initializer_range),
|
||||
name="enc_to_dec_proj",
|
||||
)
|
||||
|
||||
if self.encoder.get_output_embeddings() is not None:
|
||||
raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head")
|
||||
raise ValueError(
|
||||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@@ -355,16 +379,16 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
if encoder is None:
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
@@ -387,15 +411,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
if decoder is None:
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -404,7 +431,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
kwargs_decoder["name"] = "decoder"
|
||||
@@ -485,6 +516,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Let the user be responsible for the expected format.
|
||||
if encoder_outputs is not None:
|
||||
if return_dict and not isinstance(encoder_outputs, ModelOutput):
|
||||
raise ValueError(
|
||||
"If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
|
||||
f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
|
||||
)
|
||||
|
||||
if encoder_outputs is None:
|
||||
|
||||
encoder_processing_inputs = {
|
||||
@@ -518,6 +557,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
decoder_processing_inputs = {
|
||||
"func": self.decoder.call,
|
||||
"config": self.decoder.config,
|
||||
@@ -562,14 +608,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
||||
output = tuple([x for x in output if x is not None])
|
||||
return output
|
||||
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
||||
if not isinstance(encoder_outputs, TFBaseModelOutput):
|
||||
encoder_outputs = TFBaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
|
||||
@@ -195,6 +195,15 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
if not isinstance(config, self.config_class):
|
||||
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
||||
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
|
||||
# initialize with config
|
||||
# make sure input & output embeddings is not tied
|
||||
config.tie_word_embeddings = False
|
||||
@@ -225,7 +234,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
# get encoder output hidden size
|
||||
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
|
||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
||||
if (
|
||||
self.encoder_output_dim != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||
|
||||
@@ -248,11 +260,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# At the moment fast initialization is not supported
|
||||
# for composite models
|
||||
# At the moment fast initialization is not supported for composite models
|
||||
if kwargs.get("_fast_init", False):
|
||||
logger.warning(
|
||||
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. Falling back to slow intialization..."
|
||||
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. "
|
||||
"Falling back to slow initialization..."
|
||||
)
|
||||
kwargs["_fast_init"] = False
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
@@ -346,13 +358,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
if encoder is None:
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. In this case make sure that `encoder_pretrained_model_name_or_path` defined"
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
@@ -368,7 +380,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
if decoder is None:
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
@@ -376,8 +389,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -389,7 +403,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
@@ -472,8 +487,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# project encoder_hidden_states
|
||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
||||
# optionally project encoder_hidden_states
|
||||
if (
|
||||
self.encoder_output_dim != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
# compute correct encoder attention mask
|
||||
|
||||
@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
|
||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||
|
||||
|
||||
@@ -301,8 +303,8 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
f"it has to be equal to the encoder's `hidden_size`."
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
@@ -781,19 +783,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined"
|
||||
"to be defined."
|
||||
)
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModel
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder "
|
||||
"model. Cross-attention and casual mask are disabled."
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
@@ -811,17 +809,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention "
|
||||
f"layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if "
|
||||
f"{decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -830,11 +826,11 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order "
|
||||
f"to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the "
|
||||
"attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to "
|
||||
"`.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to "
|
||||
"`.from_encoder_decoder_pretrained(...)`"
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
@@ -178,8 +178,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
if config.decoder.cross_attention_hidden_size is not None:
|
||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||
raise ValueError(
|
||||
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
f"it has to be equal to the encoder's `hidden_size`."
|
||||
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||
"it has to be equal to the encoder's `hidden_size`. "
|
||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||
)
|
||||
@@ -241,7 +241,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
# At the moment fast initialization is not supported for composite models
|
||||
if kwargs.get("_fast_init", False):
|
||||
logger.warning(
|
||||
"Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..."
|
||||
"Fast initialization is currently not supported for VisionEncoderDecoderModel. "
|
||||
"Falling back to slow initialization..."
|
||||
)
|
||||
kwargs["_fast_init"] = False
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
@@ -334,14 +335,13 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
if encoder is None:
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. "
|
||||
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined"
|
||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
@@ -357,16 +357,18 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
if decoder is None:
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||
"to be defined."
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model."
|
||||
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||
"cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
@@ -375,11 +377,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder."
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config`"
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` "
|
||||
f"to `.from_encoder_decoder_pretrained(...)`"
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
@@ -19,8 +19,8 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
|
||||
from .test_modeling_flax_bert import FlaxBertModelTester
|
||||
from .test_modeling_flax_common import ids_tensor
|
||||
@@ -35,6 +35,15 @@ if is_flax_available():
|
||||
FlaxEncoderDecoderModel,
|
||||
FlaxGPT2LMHeadModel,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import EncoderDecoderModel
|
||||
|
||||
|
||||
@require_flax
|
||||
@@ -234,6 +243,71 @@ class FlaxEncoderDecoderMixin:
|
||||
generated_sequences = generated_output.sequences
|
||||
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = EncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = EncoderDecoderModel(encoder_decoder_config)
|
||||
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||
@@ -258,6 +332,44 @@ class FlaxEncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = np.concatenate(
|
||||
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
|
||||
)
|
||||
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
decoder_config.use_cache = False
|
||||
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# check `enc_to_dec_proj` work as expected
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
|
||||
@@ -31,6 +31,8 @@ from .test_modeling_tf_roberta import TFRobertaModelTester
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@@ -309,6 +311,90 @@ class TFEncoderDecoderMixin:
|
||||
)
|
||||
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
tf_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
tf_outputs = tf_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
# PT -> TF
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model_loaded.config = pt_model.config
|
||||
|
||||
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
||||
|
||||
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
pt_model = EncoderDecoderModel(encoder_decoder_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
tf_model.config = pt_model.config
|
||||
|
||||
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||
|
||||
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
|
||||
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
|
||||
# the encoder/decoder models.
|
||||
# There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
|
||||
# https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
|
||||
# (the change in `src/transformers/modeling_tf_utils.py`)
|
||||
_tf_model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||
# Make sure model is built
|
||||
_tf_model(**inputs_dict)
|
||||
|
||||
# Using `tf_model` to pass the test.
|
||||
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
|
||||
decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
|
||||
# Make sure models are built
|
||||
encoder(encoder.dummy_inputs)
|
||||
decoder(decoder.dummy_inputs)
|
||||
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
|
||||
tf_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
tf_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
|
||||
)
|
||||
# This is only for copying some specific attributes of this particular model.
|
||||
pt_model.config = tf_model.config
|
||||
|
||||
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||
|
||||
def test_encoder_decoder_model(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model(**input_ids_dict)
|
||||
@@ -341,6 +427,65 @@ class TFEncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
# Keep only common arguments
|
||||
arg_names = [
|
||||
"config",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_config",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"encoder_hidden_states",
|
||||
]
|
||||
config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names}
|
||||
|
||||
config = config_inputs_dict.pop("config")
|
||||
decoder_config = config_inputs_dict.pop("decoder_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
# `encoder_hidden_states` is not used in model call/forward
|
||||
del inputs_dict["encoder_hidden_states"]
|
||||
|
||||
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
|
||||
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
|
||||
inputs_dict["decoder_attention_mask"] = tf.constant(
|
||||
np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
|
||||
)
|
||||
|
||||
# TF models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
decoder_config.use_cache = False
|
||||
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
|
||||
# which randomly initialize `enc_to_dec_proj`.
|
||||
# # check `enc_to_dec_proj` work as expected
|
||||
# decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
|
||||
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
|
||||
|
||||
# Let's just check `enc_to_dec_proj` can run for now
|
||||
decoder_config.hidden_size = decoder_config.hidden_size * 2
|
||||
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
model = TFEncoderDecoderModel(encoder_decoder_config)
|
||||
model(**inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
|
||||
Reference in New Issue
Block a user