here's one big commit
This commit is contained in:
@@ -17,13 +17,12 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
self.decoder = decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
encoder_pretrained_model_name_or_path=None,
|
||||
decoder_pretrained_model_name_or_path=None,
|
||||
*model_args,
|
||||
**kwargs
|
||||
):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
|
||||
@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
kwargs_decoder = {}
|
||||
kwargs_encoder = kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
if key.startswith("decoder_"):
|
||||
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs.pop("encoder_model", None)
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("encoder_model", None)
|
||||
if encoder is None:
|
||||
kwargs_encoder["is_decoder"] = False
|
||||
encoder = AutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
|
||||
decoder = kwargs.pop("decoder_model", None)
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
kwargs_decoder["is_decoder"] = True
|
||||
decoder = AutoModelWithLMHead.from_pretrained(
|
||||
@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a Seq2Seq model and its configuration file in a format
|
||||
such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
|
||||
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||
|
||||
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||
""" The forward pass on a seq2eq depends what we are performing:
|
||||
|
||||
@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
"""
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
kwargs_decoder = {}
|
||||
kwargs_encoder = kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
if key.startswith("decoder_"):
|
||||
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
|
||||
encoder_hidden_states = encoder_outputs[0][
|
||||
-1
|
||||
] # output of the encoder *stack*
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
**kwargs)
|
||||
|
||||
if (
|
||||
"bert" not in pretrained_model_name_or_path
|
||||
or "roberta" in pretrained_model_name_or_path
|
||||
or "distilbert" in pretrained_model_name_or_path
|
||||
):
|
||||
raise ValueError("Only the Bert model is currently supported.")
|
||||
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Some architectures require for the decoder to be initialized randomly
|
||||
# before fine-tuning.
|
||||
if kwargs.get("decoder_initialize_randomly", False):
|
||||
model.decoder.init_weights()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user