resolve PR comments

This commit is contained in:
Rémi Louf
2019-10-29 17:10:20 +01:00
parent 4c3ac4a7d8
commit dfce409691
7 changed files with 647 additions and 473 deletions

View File

@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
r"""
:class:`~transformers.Seq2seq` is a generic model class that will be
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
instantiated as a Seq2seq model with one of the base model classes of
the library as encoder and (optionally) as decoder when created with
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
*model_args,
**kwargs
):
r""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("encoder_model", None)
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
kwargs_encoder["is_decoder"] = False
encoder = AutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
encoder.config.is_decoder = False
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
kwargs_decoder["is_decoder"] = True
decoder = AutoModelWithLMHead.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder
)
decoder.config.is_decoder = True
model = cls(encoder, decoder)
@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary.
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0][
-1
] # output of the encoder *stack*
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
else:
encoder_outputs = ()
# Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder are of the same family. If the
name of or the path to a pretrained model is specified, the encoder and
the decoder will be initialized with the pretrained weights (the
cross-attention will be initialized randomly if its weights are not
present).
It is possible to override this behavior and initialize, say, the decoder randomly
by creating it beforehand as follows
config = BertConfig.from_pretrained()
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
super(Model2Model, self).__init__(*args, **kwargs)
self.tie_weights()
@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
model = super(Model2Model, cls).from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs
)
# Some architectures require the decoder to be initialized randomly
# before fine-tuning.
if kwargs.get("decoder_initialize_randomly", False):
model.decoder.init_weights()
return model