From 56e2ee4eadc482a31ca46c97c3cc236824869510 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 17 Oct 2019 16:33:31 +0200 Subject: [PATCH] fix model2model --- transformers/modeling_seq2seq.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index cc5cc53bc3..ca3b9dc87a 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -28,7 +28,7 @@ from .modeling_utils import PreTrainedModel, SequenceSummary logger = logging.getLogger(__name__) -class PreTrainedSeq2seq(PreTrainedModel): +class PreTrainedSeq2seq(nn.Module): r""" :class:`~transformers.Seq2seq` is a generic model class that will be instantiated as a Seq2seq model with one of the base model classes of @@ -43,7 +43,7 @@ class PreTrainedSeq2seq(PreTrainedModel): self.decoder = decoder @classmethod - def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs): + def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs): r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. @@ -177,8 +177,8 @@ class PreTrainedSeq2seq(PreTrainedModel): class Model2Model(PreTrainedSeq2seq): - def __init__(self): - super(Model2Model, self).__init__() + def __init__(self, *args, **kwargs): + super(Model2Model, self).__init__(*args, **kwargs) self.tie_weights() def tie_weights(self): @@ -197,7 +197,14 @@ class Model2Model(PreTrainedSeq2seq): by a model-specific keyword (bert, )... """ # self._tie_or_clone_weights(self.encoder, self.decoder) - raise NotImplementedError + pass + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path, + decoder_pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs) + return model class Model2LSTM(PreTrainedSeq2seq):