add examples to doc (#4045)

This commit is contained in:
Patrick von Platen
2020-04-28 16:33:23 +02:00
committed by GitHub
parent fa49b9afea
commit 9a0a8c1c6f
2 changed files with 48 additions and 10 deletions

View File

@@ -27,18 +27,35 @@ class EncoderDecoderConfig(PretrainedConfig):
r""" r"""
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a `EncoderDecoderModel`. :class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a `EncoderDecoderModel`.
It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs. It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` Configuration objects inherit from :class:`~transformers.PretrainedConfig`
and can be used to control the model outputs. and can be used to control the model outputs.
See the documentation for :class:`~transformers.PretrainedConfig` for more information. See the documentation for :class:`~transformers.PretrainedConfig` for more information.
Args:
kwargs (`optional`):
Remaining dictionary of keyword arguments. Notably:
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the encoder config.
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the decoder config.
Arguments: Example::
kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably:
encoder (:class:`PretrainedConfig`, optional, defaults to `None`): from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
An instance of a configuration object that defines the encoder config.
encoder (:class:`PretrainedConfig`, optional, defaults to `None`): # Initializing a BERT bert-base-uncased style configuration
An instance of a configuration object that defines the decoder config. config_encoder = BertConfig()
config_decoder = BertConfig()
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
# Initializing a Bert2Bert model from the bert-base-uncased style configurations
model = EncoderDecoderModel(config=config)
# Accessing the model configuration
config_encoder = model.config.encoder
config_decoder = model.config.decoder
""" """
model_type = "encoder_decoder" model_type = "encoder_decoder"

View File

@@ -125,6 +125,8 @@ class EncoderDecoderModel(PreTrainedModel):
Examples:: Examples::
from tranformers import EncoderDecoder
model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
""" """
@@ -230,6 +232,25 @@ class EncoderDecoderModel(PreTrainedModel):
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function. - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
Examples::
from transformers import EncoderDecoderModel, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
# forward
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
# training
loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, lm_labels=input_ids)[:2]
# generation
generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
""" """
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}