comment the seq2seq functions
This commit is contained in:
@@ -43,13 +43,21 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, encoder_pretrained_model_name_or_path, decoder_pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
r""" Instantiates one of the base model classes of the library
|
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||||
from a pre-trained model configuration.
|
of the library from pre-trained model checkpoints.
|
||||||
|
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
To train the model, you should first set it back in training mode with `model.train()`
|
To train the model, you need to first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name_or_path: either:
|
encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
|
||||||
|
decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either:
|
||||||
|
|
||||||
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
@@ -84,21 +92,17 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
output_loading_info: (`optional`) boolean:
|
output_loading_info: (`optional`) boolean:
|
||||||
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
kwargs: (`optional`) Remaining dictionary of keyword arguments.
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
||||||
|
|
||||||
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
||||||
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
You can specify different kwargs for the decoder by prefixing the key with `decoder_` (e.g. ``decoder_output_attention=True``).
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
model = AutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||||
model = AutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
|
||||||
model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
|
||||||
assert model.config.output_attention == True
|
|
||||||
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
|
||||||
config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
|
|
||||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||||
@@ -115,35 +119,49 @@ class PreTrainedSeq2seq(nn.Module):
|
|||||||
encoder = kwargs.pop('encoder_model', None)
|
encoder = kwargs.pop('encoder_model', None)
|
||||||
if encoder is None:
|
if encoder is None:
|
||||||
kwargs_encoder['is_decoder'] = False
|
kwargs_encoder['is_decoder'] = False
|
||||||
encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||||
|
|
||||||
decoder = kwargs.pop('decoder_model', None)
|
decoder = kwargs.pop('decoder_model', None)
|
||||||
if decoder is None:
|
if decoder is None:
|
||||||
kwargs_decoder['is_decoder'] = True
|
kwargs_decoder['is_decoder'] = True
|
||||||
decoder_model = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
|
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||||
|
|
||||||
model = cls(encoder, decoder)
|
model = cls(encoder, decoder)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def forward(self, *inputs, **kwargs):
|
def forward(self, *inputs, **kwargs):
|
||||||
# Extract decoder inputs
|
""" The forward pass on a seq2seq model depends on what we are performing:
|
||||||
decoder_kwargs = {}
|
|
||||||
for key in kwargs.keys():
|
|
||||||
if key.startswith('decoder_'):
|
|
||||||
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
|
|
||||||
|
|
||||||
# Compute encoder hidden states if needed
|
- During training we perform one forward pass through both the encoder
|
||||||
encoder_hidden_states = kwargs.pop('encoder_hidden_states', None)
|
and decoder;
|
||||||
|
- During prediction, we perform one forward pass through the encoder,
|
||||||
|
and then perform several forward passes with the encoder's hidden
|
||||||
|
state through the decoder to decode a full sequence.
|
||||||
|
|
||||||
|
Therefore, we skip the forward pass on the encoder if an argument named
|
||||||
|
`encoder_hidden_states` is passed to this function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||||
|
# decoder-specific if the key starts with `decoder_`
|
||||||
|
kwargs_decoder = {}
|
||||||
|
kwargs_encoder = kwargs
|
||||||
|
for key in kwargs_encoder.keys():
|
||||||
|
if key.startswith('decoder_'):
|
||||||
|
kwargs_decoder[key.replace('decoder_', '')] = kwargs_encoder.pop(key)
|
||||||
|
|
||||||
|
# Encode if needed (training, first prediction pass)
|
||||||
|
encoder_hidden_states = kwargs_encoder.pop('encoder_hidden_states', None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(*inputs, *kwargs)
|
encoder_outputs = self.encoder(*inputs, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[0]
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
else:
|
else:
|
||||||
encoder_outputs = (,)
|
encoder_outputs = ()
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decoder_kwargs['encoder_hidden_states'] = encoder_hidden_states
|
kwargs_decoder['encoder_hidden_states'] = encoder_hidden_states
|
||||||
decoder_outputs = self.decoder(**decoder_kwargs)
|
decoder_outputs = self.decoder(**kwargs_decoder)
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
@@ -162,10 +180,9 @@ class Model2LSTM(PreTrainedSeq2seq):
|
|||||||
if 'decoder_config' not in kwargs:
|
if 'decoder_config' not in kwargs:
|
||||||
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
|
raise ValueError("To load an LSTM in Seq2seq model, please supply either: "
|
||||||
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
|
||||||
" - a dictionary of configuration parameters that will be used to initialize a
|
" - a dictionary of configuration parameters that will be used to initialize a"
|
||||||
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
" torch.nn.LSTM model as `decoder_config` keyword argument. "
|
||||||
" E.g. `decoder_config=\{'input_size': 768, 'hidden_size': 768, 'num_layers': 2\}`")
|
" E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
|
||||||
kwargs['decoder_model'] = torch.nn.LSTM(kwarg.pop('decoder_config'))
|
kwargs['decoder_model'] = torch.nn.LSTM(kwargs.pop('decoder_config'))
|
||||||
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user