From 4c81960b9bc0f553ddf800df16bb82804e162bcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 15 Oct 2019 17:53:38 +0200 Subject: [PATCH] comment the seq2seq functions --- transformers/modeling_seq2seq.py | 81 +++++++++++++++++++------------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 2154a4699d..b326f2bc1e 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -43,13 +43,21 @@ class PreTrainedSeq2seq(nn.Module): @classmethod def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs): - r""" Instantiates one of the base model classes of the library - from a pre-trained model configuration. - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) - To train the model, you should first set it back in training mode with `model.train()` + r""" Instantiates an encoder and a decoder from one or two base classes + of the library from pre-trained model checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you need to first set it back in training mode with `model.train()` Params: - pretrained_model_name_or_path: either: + encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. @@ -84,21 +92,17 @@ class PreTrainedSeq2seq(nn.Module): output_loading_info: (`optional`) boolean: Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. - kwargs: (`optional`) Remaining dictionary of keyword arguments: + kwargs: (`optional`) Remaining dictionary of keyword arguments. Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + You can specify different kwargs for the decoder by prefixing the key with `decoder_` (e.g. ``decoder_output_attention=True``). + Examples:: - model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. - model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` - model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading - assert model.config.output_attention == True - # Loading from a TF checkpoint file instead of a PyTorch model (slower) - config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') - model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert """ # Separate the encoder- and decoder- specific kwargs. A kwarg is @@ -115,35 +119,49 @@ class PreTrainedSeq2seq(nn.Module): encoder = kwargs.pop('encoder_model', None) if encoder is None: kwargs_encoder['is_decoder'] = False - encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder) + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) decoder = kwargs.pop('decoder_model', None) if decoder is None: kwargs_decoder['is_decoder'] = True - decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) + decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) model = cls(encoder, decoder) return model def forward(self, *inputs, **kwargs): - # Extract decoder inputs - decoder_kwargs = {} - for key in kwargs.keys(): - if key.startswith('decoder_'): - decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key) + """ The forward pass on a seq2eq depends what we are performing: - # Compute encoder hidden states if needed - encoder_hidden_states = kwargs.pop('encoder_hidden_states', None) + - During training we perform one forward pass through both the encoder + and decoder; + - During prediction, we perform one forward pass through the encoder, + and then perform several forward passes with the encoder's hidden + state through the decoder to decode a full sequence. + + Therefore, we skip the forward pass on the encoder if an argument named + `encoder_hidden_state` is passed to this function. + + """ + # Separate the encoder- and decoder- specific kwargs. A kwarg is + # decoder-specific it the key starts with `decoder_` + kwargs_decoder = {} + kwargs_encoder = kwargs + for key in kwargs_encoder.keys(): + if key.startswith('decoder_'): + kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key) + + # Encode if needed (training, first prediction pass) + encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None) if encoder_hidden_states is None: - encoder_outputs = self.encoder(*inputs, *kwargs) + encoder_outputs = self.encoder(*inputs, **kwargs_encoder) encoder_hidden_states = encoder_outputs[0] else: - encoder_outputs = (,) + encoder_outputs = () # Decode - decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states - decoder_outputs = self.decoder(**decoder_kwargs) + kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs = self.decoder(**kwargs_decoder) return decoder_outputs + encoder_outputs @@ -161,11 +179,10 @@ class Model2LSTM(PreTrainedSeq2seq): # We will create a randomly initilized LSTM model as decoder if 'decoder_config' not in kwargs: raise ValueError("To load an LSTM in Seq2seq model, please supply either: " - " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or " - " - a dictionary of configuration parameters that will be used to initialize a - " torch.nn.LSTM model as `decoder_config` keyword argument. " - " E.g. `decoder_config=\{'input_size': 768, 'hidden_size': 768, 'num_layers': 2\}`") - kwargs['decoder_model'] = torch.nn.LSTM(kwarg.pop('decoder_config')) + " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or" + " - a dictionary of configuration parameters that will be used to initialize a" + " torch.nn.LSTM model as `decoder_config` keyword argument. " + " E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`") + kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config')) model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs) return model -