From 6d6c32673726896d682f71a40476576972d127b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 15 Oct 2019 16:07:07 +0200 Subject: [PATCH] take path to pretrained for encoder and decoder for init --- transformers/modeling_seq2seq.py | 61 ++++++++++++++------------------ 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 466a101f47..2154a4699d 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -21,21 +21,20 @@ import logging import torch from torch import nn -from .modeling_auto import AutoModel, AutoModelWithLMHead - -from .modeling_utils import PreTrainedModel, SequenceSummary - from .file_utils import add_start_docstrings +from .modeling_auto import AutoModel, AutoModelWithLMHead +from .modeling_utils import PreTrainedModel, SequenceSummary logger = logging.getLogger(__name__) class PreTrainedSeq2seq(nn.Module): r""" - :class:`~transformers.Seq2seq` is a generic model class - that will be instantiated as a Seq2seq model with one of the base model classes of the library - as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` - class method. + :class:`~transformers.Seq2seq` is a generic model class that will be + instantiated as a Seq2seq model with one of the base model classes of + the library as encoder and (optionally) as decoder when created with + the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class + method. 
""" def __init__(self, encoder, decoder): super(PreTrainedSeq2seq, self).__init__() @@ -43,7 +42,7 @@ class PreTrainedSeq2seq(nn.Module): self.decoder = decoder @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs): r""" Instantiates one of the base model classes of the library from a pre-trained model configuration. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) @@ -100,40 +99,34 @@ class PreTrainedSeq2seq(nn.Module): # Loading from a TF checkpoint file instead of a PyTorch model (slower) config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - # Extract encoder and decoder model if provided - encoder_model = kwargs.pop('encoder_model', None) - decoder_model = kwargs.pop('decoder_model', None) - # Extract decoder kwargs so we only have encoder kwargs for now - if decoder_model is None: - decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path) - decoder_kwargs = {} - for key in kwargs.keys(): + # Separate the encoder- and decoder- specific kwargs. 
A kwarg is + # decoder-specific if the key starts with `decoder_` + kwargs_decoder = {} + kwargs_encoder = kwargs + for key in kwargs_encoder.keys(): if key.startswith('decoder_'): - decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key) + kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key) - # Load and initialize the decoder - if encoder_model: - encoder = encoder_model - else: - # Load and initialize the encoder - kwargs['is_decoder'] = False # Make sure the encoder will be an encoder - encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs.pop('encoder_model', None) + if encoder is None: + kwargs_encoder['is_decoder'] = False + encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder) - # Load and initialize the decoder - if decoder_model: - decoder = decoder_model - else: - kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc... - kwargs['is_decoder'] = True # Make sure the decoder will be a decoder - decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + decoder = kwargs.pop('decoder_model', None) + if decoder is None: + kwargs_decoder['is_decoder'] = True + decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) model = cls(encoder, decoder) + return model - def forward(self, *inputs, *kwargs): + def forward(self, *inputs, **kwargs): # Extract decoder inputs decoder_kwargs = {} for key in kwargs.keys():