Use cross_attention_hidden_size in Encoder-Decoder models (#14378)
* add cross_attention_hidden_size to text-2-text encoder-decoder models (PT/Flax)
* for TFEncoderDecoderModel
* add equivalence test for TFEncoderDecoderModel
* fix
* fix failed equivalence tests
* remove unused import
* add detailed comment
* Fix check_equivalence_tf_to_pt by using encoder/decoder
* cleaning
* Use cross_attention_hidden_size in speech-to-text
* clean fast init logging msg in encoder decoder models
* increase tol from 1e-5 to 1e-3 for tf test
* style
* style
* make sure projection layer can run
* remove type conversion + add check
* fix conflict (config.output_hidden_size)
* Remove TF -> PT in check_pt_tf_equivalence for TFEncoderDecoderModel

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import warnings
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
@@ -25,6 +26,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
|||||||
from ...modeling_outputs import Seq2SeqLMOutput
|
from ...modeling_outputs import Seq2SeqLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
|
||||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -181,13 +184,23 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder: Optional[PreTrainedModel] = None,
|
encoder: Optional[PreTrainedModel] = None,
|
||||||
decoder: Optional[PreTrainedModel] = None,
|
decoder: Optional[PreTrainedModel] = None,
|
||||||
):
|
):
|
||||||
assert config is not None or (
|
if config is None and (encoder is None or decoder is None):
|
||||||
encoder is not None and decoder is not None
|
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
|
||||||
), "Either a configuration or an Encoder and a decoder has to be provided"
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||||
else:
|
else:
|
||||||
assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
|
if not isinstance(config, self.config_class):
|
||||||
|
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
||||||
|
|
||||||
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
|
)
|
||||||
|
|
||||||
# initialize with config
|
# initialize with config
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -218,9 +231,17 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
self.encoder.config = self.config.encoder
|
self.encoder.config = self.config.encoder
|
||||||
self.decoder.config = self.config.decoder
|
self.decoder.config = self.config.decoder
|
||||||
|
|
||||||
assert (
|
# encoder outputs might need to be projected to different dimension for decoder
|
||||||
self.encoder.get_output_embeddings() is None
|
if (
|
||||||
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
|
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||||
|
|
||||||
|
if self.encoder.get_output_embeddings() is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||||
|
)
|
||||||
|
|
||||||
# tie encoder, decoder weights if config set accordingly
|
# tie encoder, decoder weights if config set accordingly
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
@@ -251,8 +272,12 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
# At the moment fast initialization is not supported
|
# At the moment fast initialization is not supported for composite models
|
||||||
# for composite models
|
if kwargs.get("_fast_init", False):
|
||||||
|
logger.warning(
|
||||||
|
"Fast initialization is currently not supported for EncoderDecoderModel. "
|
||||||
|
"Falling back to slow initialization..."
|
||||||
|
)
|
||||||
kwargs["_fast_init"] = False
|
kwargs["_fast_init"] = False
|
||||||
return super().from_pretrained(*args, **kwargs)
|
return super().from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
@@ -343,19 +368,18 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||||
encoder = kwargs_encoder.pop("model", None)
|
encoder = kwargs_encoder.pop("model", None)
|
||||||
if encoder is None:
|
if encoder is None:
|
||||||
assert (
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
encoder_pretrained_model_name_or_path is not None
|
raise ValueError(
|
||||||
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
from ..auto.modeling_auto import AutoModel
|
"to be defined."
|
||||||
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
)
|
)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
encoder_config.add_cross_attention = False
|
encoder_config.add_cross_attention = False
|
||||||
@@ -366,18 +390,20 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
decoder = kwargs_decoder.pop("model", None)
|
decoder = kwargs_decoder.pop("model", None)
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
assert (
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
decoder_pretrained_model_name_or_path is not None
|
raise ValueError(
|
||||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
from ..auto.modeling_auto import AutoModelForCausalLM
|
"to be defined."
|
||||||
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -386,7 +412,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
@@ -464,6 +494,13 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
# optionally project encoder_hidden_states
|
||||||
|
if (
|
||||||
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
|
||||||
decoder_input_ids = shift_tokens_right(
|
decoder_input_ids = shift_tokens_right(
|
||||||
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
|||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel
|
from ...modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
|
||||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -227,9 +229,25 @@ class FlaxEncoderDecoderModule(nn.Module):
|
|||||||
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
|
||||||
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
|
||||||
|
|
||||||
|
# encoder outputs might need to be projected to different dimension for decoder
|
||||||
|
if (
|
||||||
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
|
self.enc_to_dec_proj = nn.Dense(
|
||||||
|
self.decoder.config.hidden_size,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.enc_to_dec_proj = None
|
||||||
|
|
||||||
def _get_encoder_module(self):
|
def _get_encoder_module(self):
|
||||||
return self.encoder
|
return self.encoder
|
||||||
|
|
||||||
|
def _get_projection_module(self):
|
||||||
|
return self.enc_to_dec_proj
|
||||||
|
|
||||||
def _get_decoder_module(self):
|
def _get_decoder_module(self):
|
||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
@@ -256,11 +274,17 @@ class FlaxEncoderDecoderModule(nn.Module):
|
|||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
# optionally project encoder_hidden_states
|
||||||
|
if self.enc_to_dec_proj is not None:
|
||||||
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
position_ids=decoder_position_ids,
|
position_ids=decoder_position_ids,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -305,6 +329,15 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = ((1, 1), (1, 1))
|
input_shape = ((1, 1), (1, 1))
|
||||||
|
|
||||||
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
|
)
|
||||||
|
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@@ -537,12 +570,22 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
mutable = False
|
mutable = False
|
||||||
|
|
||||||
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
|
def _decoder_forward(
|
||||||
|
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
projection_module = module._get_projection_module()
|
||||||
decoder_module = module._get_decoder_module()
|
decoder_module = module._get_decoder_module()
|
||||||
|
|
||||||
|
# optionally project encoder_hidden_states
|
||||||
|
if projection_module is not None:
|
||||||
|
encoder_hidden_states = projection_module(encoder_hidden_states)
|
||||||
|
|
||||||
return decoder_module(
|
return decoder_module(
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
|
encoder_hidden_states,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -772,19 +815,18 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||||
encoder = kwargs_encoder.pop("model", None)
|
encoder = kwargs_encoder.pop("model", None)
|
||||||
if encoder is None:
|
if encoder is None:
|
||||||
assert (
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
encoder_pretrained_model_name_or_path is not None
|
raise ValueError(
|
||||||
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
from ..auto.modeling_flax_auto import FlaxAutoModel
|
"to be defined."
|
||||||
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
)
|
)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
encoder_config.add_cross_attention = False
|
encoder_config.add_cross_attention = False
|
||||||
@@ -797,18 +839,20 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
decoder = kwargs_decoder.pop("model", None)
|
decoder = kwargs_decoder.pop("model", None)
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
assert (
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
decoder_pretrained_model_name_or_path is not None
|
raise ValueError(
|
||||||
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
|
"to be defined."
|
||||||
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -817,7 +861,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|||||||
@@ -23,12 +23,13 @@ import tensorflow as tf
|
|||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, input_processing
|
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
|
||||||
@@ -168,12 +169,22 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
decoder: Optional[TFPreTrainedModel] = None,
|
decoder: Optional[TFPreTrainedModel] = None,
|
||||||
):
|
):
|
||||||
if config is None and (encoder is None or decoder is None):
|
if config is None and (encoder is None or decoder is None):
|
||||||
raise ValueError("Either a configuration or an encoder and a decoder has to be provided")
|
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
|
||||||
if config is None:
|
if config is None:
|
||||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||||
else:
|
else:
|
||||||
if not isinstance(config, self.config_class):
|
if not isinstance(config, self.config_class):
|
||||||
raise ValueError(f"config: {config} has to be of type {self.config_class}")
|
raise ValueError(f"config: {config} has to be of type {self.config_class}")
|
||||||
|
|
||||||
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
|
)
|
||||||
|
|
||||||
# initialize with config
|
# initialize with config
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -200,8 +211,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
self.encoder.config = self.config.encoder
|
self.encoder.config = self.config.encoder
|
||||||
self.decoder.config = self.config.decoder
|
self.decoder.config = self.config.decoder
|
||||||
|
|
||||||
|
# encoder outputs might need to be projected to different dimension for decoder
|
||||||
|
if (
|
||||||
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
|
self.enc_to_dec_proj = tf.keras.layers.Dense(
|
||||||
|
units=self.decoder.config.hidden_size,
|
||||||
|
kernel_initializer=get_initializer(config.encoder.initializer_range),
|
||||||
|
name="enc_to_dec_proj",
|
||||||
|
)
|
||||||
|
|
||||||
if self.encoder.get_output_embeddings() is not None:
|
if self.encoder.get_output_embeddings() is not None:
|
||||||
raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head")
|
raise ValueError(
|
||||||
|
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
@@ -355,16 +379,16 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
if encoder is None:
|
if encoder is None:
|
||||||
if encoder_pretrained_model_name_or_path is None:
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
|
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
)
|
)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
encoder_config.add_cross_attention = False
|
encoder_config.add_cross_attention = False
|
||||||
@@ -387,15 +411,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
if decoder is None:
|
if decoder is None:
|
||||||
if decoder_pretrained_model_name_or_path is None:
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
|
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -404,7 +431,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
|
|
||||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs_decoder["name"] = "decoder"
|
kwargs_decoder["name"] = "decoder"
|
||||||
@@ -485,6 +516,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Let the user be responsible for the expected format.
|
||||||
|
if encoder_outputs is not None:
|
||||||
|
if return_dict and not isinstance(encoder_outputs, ModelOutput):
|
||||||
|
raise ValueError(
|
||||||
|
"If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
|
||||||
|
f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
|
||||||
|
)
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
|
|
||||||
encoder_processing_inputs = {
|
encoder_processing_inputs = {
|
||||||
@@ -518,6 +557,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
# optionally project encoder_hidden_states
|
||||||
|
if (
|
||||||
|
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
decoder_processing_inputs = {
|
decoder_processing_inputs = {
|
||||||
"func": self.decoder.call,
|
"func": self.decoder.call,
|
||||||
"config": self.decoder.config,
|
"config": self.decoder.config,
|
||||||
@@ -562,14 +608,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
|
|||||||
output = tuple([x for x in output if x is not None])
|
output = tuple([x for x in output if x is not None])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
|
||||||
if not isinstance(encoder_outputs, TFBaseModelOutput):
|
|
||||||
encoder_outputs = TFBaseModelOutput(
|
|
||||||
last_hidden_state=encoder_outputs[0],
|
|
||||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
|
||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return TFSeq2SeqLMOutput(
|
return TFSeq2SeqLMOutput(
|
||||||
loss=decoder_outputs.loss,
|
loss=decoder_outputs.loss,
|
||||||
logits=decoder_outputs.logits,
|
logits=decoder_outputs.logits,
|
||||||
|
|||||||
@@ -195,6 +195,15 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
if not isinstance(config, self.config_class):
|
if not isinstance(config, self.config_class):
|
||||||
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
||||||
|
|
||||||
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
|
)
|
||||||
|
|
||||||
# initialize with config
|
# initialize with config
|
||||||
# make sure input & output embeddings is not tied
|
# make sure input & output embeddings is not tied
|
||||||
config.tie_word_embeddings = False
|
config.tie_word_embeddings = False
|
||||||
@@ -225,7 +234,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
# get encoder output hidden size
|
# get encoder output hidden size
|
||||||
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
|
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
|
||||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
if (
|
||||||
|
self.encoder_output_dim != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
# encoder outputs might need to be projected to different dimension for decoder
|
# encoder outputs might need to be projected to different dimension for decoder
|
||||||
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||||
|
|
||||||
@@ -248,11 +260,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
# At the moment fast initialization is not supported
|
# At the moment fast initialization is not supported for composite models
|
||||||
# for composite models
|
|
||||||
if kwargs.get("_fast_init", False):
|
if kwargs.get("_fast_init", False):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. Falling back to slow intialization..."
|
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. "
|
||||||
|
"Falling back to slow initialization..."
|
||||||
)
|
)
|
||||||
kwargs["_fast_init"] = False
|
kwargs["_fast_init"] = False
|
||||||
return super().from_pretrained(*args, **kwargs)
|
return super().from_pretrained(*args, **kwargs)
|
||||||
@@ -346,13 +358,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
if encoder is None:
|
if encoder is None:
|
||||||
if encoder_pretrained_model_name_or_path is None:
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. In this case make sure that `encoder_pretrained_model_name_or_path` defined"
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
@@ -368,7 +380,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
if decoder is None:
|
if decoder is None:
|
||||||
if decoder_pretrained_model_name_or_path is None:
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
@@ -376,8 +389,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -389,7 +403,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
@@ -472,8 +487,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
# project encoder_hidden_states
|
# optionally project encoder_hidden_states
|
||||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
if (
|
||||||
|
self.encoder_output_dim != self.decoder.config.hidden_size
|
||||||
|
and self.decoder.config.cross_attention_hidden_size is None
|
||||||
|
):
|
||||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||||
|
|
||||||
# compute correct encoder attention mask
|
# compute correct encoder attention mask
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
|
|||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel
|
from ...modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
|
||||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -301,8 +303,8 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
if config.decoder.cross_attention_hidden_size is not None:
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
f"it has to be equal to the encoder's `hidden_size`."
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
)
|
)
|
||||||
@@ -781,19 +783,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
if encoder_pretrained_model_name_or_path is None:
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
"to be defined"
|
"to be defined."
|
||||||
)
|
)
|
||||||
from ..auto.modeling_flax_auto import FlaxAutoModel
|
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
"model. Cross-attention and casual mask are disabled."
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
)
|
)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
encoder_config.add_cross_attention = False
|
encoder_config.add_cross_attention = False
|
||||||
@@ -811,17 +809,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
"to be defined."
|
"to be defined."
|
||||||
)
|
)
|
||||||
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
|
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
from ..auto.configuration_auto import AutoConfig
|
|
||||||
|
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
f"layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if "
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
f"{decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -830,11 +826,11 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order "
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
f"to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the "
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
"attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to "
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
"`.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to "
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
"`.from_encoder_decoder_pretrained(...)`"
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|||||||
@@ -178,8 +178,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
if config.decoder.cross_attention_hidden_size is not None:
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
|
||||||
f"it has to be equal to the encoder's `hidden_size`."
|
"it has to be equal to the encoder's `hidden_size`. "
|
||||||
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
|
||||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
)
|
)
|
||||||
@@ -241,7 +241,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
# At the moment fast initialization is not supported for composite models
|
# At the moment fast initialization is not supported for composite models
|
||||||
if kwargs.get("_fast_init", False):
|
if kwargs.get("_fast_init", False):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..."
|
"Fast initialization is currently not supported for VisionEncoderDecoderModel. "
|
||||||
|
"Falling back to slow initialization..."
|
||||||
)
|
)
|
||||||
kwargs["_fast_init"] = False
|
kwargs["_fast_init"] = False
|
||||||
return super().from_pretrained(*args, **kwargs)
|
return super().from_pretrained(*args, **kwargs)
|
||||||
@@ -334,14 +335,13 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
if encoder is None:
|
if encoder is None:
|
||||||
if encoder_pretrained_model_name_or_path is None:
|
if encoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. "
|
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
|
||||||
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined"
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||||
@@ -357,7 +357,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
if decoder is None:
|
if decoder is None:
|
||||||
if decoder_pretrained_model_name_or_path is None:
|
if decoder_pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
|
||||||
|
"to be defined."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
@@ -365,8 +366,9 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||||
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
|
||||||
|
"cross attention layers."
|
||||||
)
|
)
|
||||||
decoder_config.is_decoder = True
|
decoder_config.is_decoder = True
|
||||||
decoder_config.add_cross_attention = True
|
decoder_config.add_cross_attention = True
|
||||||
@@ -378,8 +380,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
|
||||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
|
||||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` "
|
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
|
||||||
f"to `.from_encoder_decoder_pretrained(...)`"
|
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import is_flax_available
|
from transformers import is_flax_available, is_torch_available
|
||||||
from transformers.testing_utils import require_flax, slow
|
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||||
|
|
||||||
from .test_modeling_flax_bert import FlaxBertModelTester
|
from .test_modeling_flax_bert import FlaxBertModelTester
|
||||||
from .test_modeling_flax_common import ids_tensor
|
from .test_modeling_flax_common import ids_tensor
|
||||||
@@ -35,6 +35,15 @@ if is_flax_available():
|
|||||||
FlaxEncoderDecoderModel,
|
FlaxEncoderDecoderModel,
|
||||||
FlaxGPT2LMHeadModel,
|
FlaxGPT2LMHeadModel,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
|
convert_pytorch_state_dict_to_flax,
|
||||||
|
load_flax_weights_in_pytorch_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import EncoderDecoderModel
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@@ -234,6 +243,71 @@ class FlaxEncoderDecoderMixin:
|
|||||||
generated_sequences = generated_output.sequences
|
generated_sequences = generated_output.sequences
|
||||||
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
    """Check that a PyTorch model and a Flax model produce matching outputs.

    Three comparisons are performed, each element-wise with a tolerance of 1e-5:

    1. ``pt_model`` against ``fx_model`` as given.
    2. ``pt_model`` against a Flax model reloaded from ``pt_model``'s saved
       checkpoint (``from_pt=True``).
    3. ``fx_model`` against a PyTorch model reloaded from ``fx_model``'s saved
       checkpoint (``from_flax=True``).

    Args:
        pt_model: PyTorch model; it is moved to ``torch_device`` and put in eval mode.
        fx_model: Flax model expected to hold the same weights as ``pt_model``.
        inputs_dict: keyword inputs for the Flax model; each value must support
            ``.tolist()`` so that it can be converted to a ``torch.Tensor``.
    """
    pt_model.to(torch_device)
    pt_model.eval()

    # prepare inputs
    flax_inputs = inputs_dict
    # The PyTorch inputs have to live on the same device as the model, otherwise the
    # forward pass fails with a device mismatch whenever `torch_device` is not the CPU.
    pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs).to_tuple()

    fx_outputs = fx_model(**inputs_dict).to_tuple()
    self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
    for fx_output, pt_output in zip(fx_outputs, pt_outputs):
        # `.cpu()` is required before `.numpy()` for tensors that are not on the CPU.
        self.assert_almost_equals(fx_output, pt_output.cpu().numpy(), 1e-5)

    # PT -> Flax
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

    fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
    self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
    for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
        self.assert_almost_equals(fx_output_loaded, pt_output.cpu().numpy(), 1e-5)

    # Flax -> PT
    with tempfile.TemporaryDirectory() as tmpdirname:
        fx_model.save_pretrained(tmpdirname)
        pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)

    pt_model_loaded.to(torch_device)
    pt_model_loaded.eval()

    with torch.no_grad():
        pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

    self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
    for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
        self.assert_almost_equals(fx_output, pt_output_loaded.cpu().numpy(), 1e-5)
|
||||||
|
|
||||||
|
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
    """Build a PyTorch/Flax model pair, copy the PyTorch weights into Flax and compare.

    Both models are created from one combined encoder-decoder config built from
    ``config`` and ``decoder_config``. The Flax parameters are then replaced by the
    converted PyTorch state dict before the pair is handed to
    ``check_pt_flax_equivalence`` together with ``inputs_dict``.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

    torch_model = EncoderDecoderModel(combined_config)
    flax_model = FlaxEncoderDecoderModel(combined_config)

    # Overwrite the Flax parameters with the ones converted from the PyTorch model.
    flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

    self.check_pt_flax_equivalence(torch_model, flax_model, inputs_dict)
|
||||||
|
|
||||||
|
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
    """Build a PyTorch/Flax model pair, copy the Flax weights into PyTorch and compare.

    Both models are created from one combined encoder-decoder config built from
    ``config`` and ``decoder_config``. The Flax parameters are loaded into the
    PyTorch model, and the resulting pair is handed to
    ``check_pt_flax_equivalence`` together with ``inputs_dict``.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

    torch_model = EncoderDecoderModel(combined_config)
    flax_model = FlaxEncoderDecoderModel(combined_config)

    # Use the model returned by the loader, which now carries the Flax weights.
    torch_model = load_flax_weights_in_pytorch_model(torch_model, flax_model.params)

    self.check_pt_flax_equivalence(torch_model, flax_model, inputs_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||||
@@ -258,6 +332,44 @@ class FlaxEncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
    """Assert that two arrays agree element-wise within an absolute tolerance.

    Args:
        a: first array.
        b: second array; ``a - b`` must be computable.
        tol: largest absolute element-wise difference that is still accepted.

    The check fails (through ``self.assertLessEqual``) when the maximum absolute
    difference between ``a`` and ``b`` is strictly greater than ``tol``.
    """
    diff = np.abs(a - b).max()
    # The assertion only fails when diff > tol, so the message reports a strict inequality.
    self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol}).")
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
    """Run the PyTorch <-> Flax equivalence checks with and without encoder projection.

    The checks run twice, in both conversion directions each time: first with
    equal encoder and decoder hidden sizes, then with the decoder hidden size
    doubled so that the two sizes differ.
    """
    prepared = self.prepare_config_and_inputs()
    config = prepared.pop("config")
    decoder_config = prepared.pop("decoder_config")

    # Everything left after removing the two configs is passed to the models as inputs.
    inputs_dict = prepared
    # `encoder_hidden_states` is not used in model call/forward
    del inputs_dict["encoder_hidden_states"]

    # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
    decoder_mask = inputs_dict["decoder_attention_mask"]
    leading_ones = np.ones(shape=(decoder_mask.shape[0], 1))
    inputs_dict["decoder_attention_mask"] = np.concatenate([leading_ones, decoder_mask[:, 1:]], axis=1)

    # Flax models don't use the `use_cache` option and cache is not returned as a default.
    # So we disable `use_cache` here for PyTorch model.
    decoder_config.use_cache = False

    self.assertTrue(decoder_config.cross_attention_hidden_size is None)

    equivalence_checks = (self.check_equivalence_pt_to_flax, self.check_equivalence_flax_to_pt)

    # check without `enc_to_dec_proj` projection
    self.assertTrue(config.hidden_size == decoder_config.hidden_size)
    for run_check in equivalence_checks:
        run_check(config, decoder_config, inputs_dict)

    # check `enc_to_dec_proj` work as expected
    decoder_config.hidden_size = decoder_config.hidden_size * 2
    self.assertTrue(config.hidden_size != decoder_config.hidden_size)
    for run_check in equivalence_checks:
        run_check(config, decoder_config, inputs_dict)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_real_model_save_load_from_pretrained(self):
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
model_2 = self.get_pretrained_model()
|
model_2 = self.get_pretrained_model()
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ from .test_modeling_tf_roberta import TFRobertaModelTester
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -309,6 +311,90 @@ class TFEncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
    """Check that a PyTorch and a TensorFlow encoder-decoder model agree on the same inputs.

    Two comparisons are made against the PyTorch outputs, each with an absolute tolerance of 1e-3:

    1. the outputs of `tf_model` as given;
    2. the outputs of a TF model rebuilt from the PyTorch encoder/decoder checkpoints (PT -> TF).

    Args:
        pt_model: PyTorch encoder-decoder model; it is moved to `torch_device` and put in eval mode.
        tf_model: TensorFlow encoder-decoder model to compare against.
        inputs_dict: keyword arguments for the TF model call. Each value must expose `.numpy()`,
            because the PyTorch inputs are derived from it.
    """

    pt_model.to(torch_device)
    pt_model.eval()

    # prepare inputs
    # The PyTorch inputs must live on the same device as the model, otherwise the forward pass fails
    # whenever `torch_device` is not the CPU.
    tf_inputs = inputs_dict
    pt_inputs = {k: torch.tensor(v.numpy()).to(torch_device) for k, v in tf_inputs.items()}

    with torch.no_grad():
        pt_outputs = pt_model(**pt_inputs).to_tuple()

    def assert_matches_pt(candidate_outputs):
        # Shared comparison: same number of outputs, and each pair within tolerance.
        self.assertEqual(len(candidate_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
        for tf_output, pt_output in zip(candidate_outputs, pt_outputs):
            # `.cpu()` is required before `.numpy()` when the PyTorch model ran on an accelerator.
            self.assert_almost_equals(tf_output.numpy(), pt_output.detach().cpu().numpy(), 1e-3)

    assert_matches_pt(tf_model(**inputs_dict).to_tuple())

    # PT -> TF
    with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:

        pt_model.encoder.save_pretrained(encoder_tmp_dirname)
        pt_model.decoder.save_pretrained(decoder_tmp_dirname)
        tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
        )
        # This is only for copying some specific attributes of this particular model.
        tf_model_loaded.config = pt_model.config

    assert_matches_pt(tf_model_loaded(**inputs_dict).to_tuple())
|
||||||
|
|
||||||
|
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
    """Create a PyTorch encoder-decoder model, port it to TensorFlow, and compare the two.

    The PyTorch encoder and decoder are saved to separate temporary directories and reloaded
    through `TFEncoderDecoderModel.from_encoder_decoder_pretrained` with the `*_from_pt` flags,
    after which `check_pt_tf_equivalence` compares both models on `inputs_dict`.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    pt_model = EncoderDecoderModel(combined_config)

    with tempfile.TemporaryDirectory() as enc_dir:
        with tempfile.TemporaryDirectory() as dec_dir:
            pt_model.encoder.save_pretrained(enc_dir)
            pt_model.decoder.save_pretrained(dec_dir)
            tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
                enc_dir,
                dec_dir,
                encoder_from_pt=True,
                decoder_from_pt=True,
            )
            # Only needed to carry over attributes specific to this particular model.
            tf_model.config = pt_model.config

    self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||||
|
|
||||||
|
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
    """Create a TensorFlow encoder-decoder model, port it to PyTorch, and compare the two.

    The TF encoder and decoder are instantiated as standalone models, wrapped in a
    `TFEncoderDecoderModel`, saved to temporary directories and reloaded through
    `EncoderDecoderModel.from_encoder_decoder_pretrained` with the `*_from_tf` flags.
    `check_pt_tf_equivalence` then compares both models on `inputs_dict`.
    """
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

    # A model built directly from the combined config cannot be used for the comparison: the test
    # fails because its weights get extended before the encoder/decoder models are saved.
    # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
    #   https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
    #   (the change in `src/transformers/modeling_tf_utils.py`)
    # It is therefore only used to find out which encoder/decoder classes to instantiate.
    composite_model = TFEncoderDecoderModel(combined_config)
    # Make sure model is built
    composite_model(**inputs_dict)

    # Standalone encoder/decoder instances are what make the test pass.
    encoder_cls = composite_model.encoder.__class__
    decoder_cls = composite_model.decoder.__class__
    encoder = encoder_cls(combined_config.encoder)
    decoder = decoder_cls(combined_config.decoder)
    # Make sure models are built
    encoder(encoder.dummy_inputs)
    decoder(decoder.dummy_inputs)
    tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)

    with tempfile.TemporaryDirectory() as enc_dir:
        with tempfile.TemporaryDirectory() as dec_dir:
            tf_model.encoder.save_pretrained(enc_dir)
            tf_model.decoder.save_pretrained(dec_dir)
            pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                enc_dir,
                dec_dir,
                encoder_from_tf=True,
                decoder_from_tf=True,
            )
            # Only needed to carry over attributes specific to this particular model.
            pt_model.config = tf_model.config

    self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model(self):
|
def test_encoder_decoder_model(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model(**input_ids_dict)
|
self.check_encoder_decoder_model(**input_ids_dict)
|
||||||
@@ -341,6 +427,65 @@ class TFEncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
    """Assert that the largest element-wise absolute difference between `a` and `b` is at most `tol`.

    Args:
        a: first array.
        b: second array, subtracted element-wise from `a`.
        tol: largest absolute difference that is still accepted (inclusive).
    """
    diff = np.abs(a - b).max()
    # `assertLessEqual` only fails when `diff` is strictly greater than `tol`, so the message says `>`.
    self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (> {tol}).")
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
    """Check PyTorch/TensorFlow equivalence of the encoder-decoder model in both porting directions.

    Also runs a forward pass with mismatched encoder/decoder hidden sizes, so that the
    `enc_to_dec_proj` projection path is exercised at least once.
    """
    # Keep only common arguments
    shared_keys = (
        "config",
        "input_ids",
        "attention_mask",
        "decoder_config",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_hidden_states",
    )
    prepared = self.prepare_config_and_inputs()
    inputs_dict = {name: value for name, value in prepared.items() if name in shared_keys}

    config = inputs_dict.pop("config")
    decoder_config = inputs_dict.pop("decoder_config")

    # `encoder_hidden_states` is not used in model call/forward
    del inputs_dict["encoder_hidden_states"]

    # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
    decoder_mask = inputs_dict["decoder_attention_mask"]
    batch_size = decoder_mask.shape[0]
    leading_ones = np.ones(shape=(batch_size, 1))
    patched_mask = np.concatenate([leading_ones, decoder_mask[:, 1:]], axis=1)
    inputs_dict["decoder_attention_mask"] = tf.constant(patched_mask)

    # TF models don't use the `use_cache` option and cache is not returned as a default.
    # So we disable `use_cache` here for PyTorch model.
    decoder_config.use_cache = False

    self.assertTrue(decoder_config.cross_attention_hidden_size is None)

    # check without `enc_to_dec_proj` projection
    self.assertTrue(config.hidden_size == decoder_config.hidden_size)
    self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
    self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)

    # A full equivalence check with `enc_to_dec_proj` is not working, because the pt/tf equivalence
    # test for encoder-decoder uses `from_encoder_decoder_pretrained`, which randomly initializes
    # `enc_to_dec_proj`. Once that is resolved, the two `check_equivalence_*` calls above can be
    # repeated after the hidden-size change below.

    # Let's just check `enc_to_dec_proj` can run for now
    decoder_config.hidden_size *= 2
    self.assertTrue(config.hidden_size != decoder_config.hidden_size)
    projected_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
    model = TFEncoderDecoderModel(projected_config)
    model(**inputs_dict)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_real_model_save_load_from_pretrained(self):
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
model_2 = self.get_pretrained_model()
|
model_2 = self.get_pretrained_model()
|
||||||
|
|||||||
Reference in New Issue
Block a user