resolve PR comments
This commit is contained in:
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
|
||||
instantiated as a Seq2seq model with one of the base model classes of
|
||||
the library as encoder and (optionally) as decoder when created with
|
||||
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||
@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
*model_args,
|
||||
**kwargs
|
||||
):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
|
||||
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
"""
|
||||
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as a whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("encoder_model", None)
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
kwargs_encoder["is_decoder"] = False
|
||||
encoder = AutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
encoder.config.is_decoder = False
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
kwargs_decoder["is_decoder"] = True
|
||||
decoder = AutoModelWithLMHead.from_pretrained(
|
||||
decoder_pretrained_model_name_or_path, **kwargs_decoder
|
||||
)
|
||||
decoder.config.is_decoder = True
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
|
||||
@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
"""
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0][
|
||||
-1
|
||||
] # output of the encoder *stack*
|
||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
|
||||
where both of the encoder and decoder are of the same family. If the
|
||||
name of or that path to a pretrained model is specified the encoder and
|
||||
the decoder will be initialized with the pretrained weight (the
|
||||
cross-attention will be intialized randomly if its weights are not
|
||||
present).
|
||||
|
||||
It is possible to override this behavior and initialize, say, the decoder randomly
|
||||
by creating it beforehand as follows
|
||||
|
||||
config = BertConfig.from_pretrained()
|
||||
decoder = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model2Model, self).__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Some architectures require for the decoder to be initialized randomly
|
||||
# before fine-tuning.
|
||||
if kwargs.get("decoder_initialize_randomly", False):
|
||||
model.decoder.init_weights()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user