add examples to doc (#4045)
This commit is contained in:
committed by
GitHub
parent
fa49b9afea
commit
9a0a8c1c6f
@@ -32,13 +32,30 @@ class EncoderDecoderConfig(PretrainedConfig):
|
|||||||
and can be used to control the model outputs.
|
and can be used to control the model outputs.
|
||||||
See the documentation for :class:`~transformers.PretrainedConfig` for more information.
|
See the documentation for :class:`~transformers.PretrainedConfig` for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
Arguments:
|
kwargs (`optional`):
|
||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably:
|
Remaining dictionary of keyword arguments. Notably:
|
||||||
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
|
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
|
||||||
An instance of a configuration object that defines the encoder config.
|
An instance of a configuration object that defines the encoder config.
|
||||||
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
|
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
|
||||||
An instance of a configuration object that defines the decoder config.
|
An instance of a configuration object that defines the decoder config.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
|
||||||
|
|
||||||
|
# Initializing a BERT bert-base-uncased style configuration
|
||||||
|
config_encoder = BertConfig()
|
||||||
|
config_decoder = BertConfig()
|
||||||
|
|
||||||
|
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
|
||||||
|
|
||||||
|
# Initializing a Bert2Bert model from the bert-base-uncased style configurations
|
||||||
|
model = EncoderDecoderModel(config=config)
|
||||||
|
|
||||||
|
# Accessing the model configuration
|
||||||
|
config_encoder = model.config.encoder
|
||||||
|
config_decoder = model.config.decoder
|
||||||
"""
|
"""
|
||||||
model_type = "encoder_decoder"
|
model_type = "encoder_decoder"
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
|
from tranformers import EncoderDecoder
|
||||||
|
|
||||||
model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -230,6 +232,25 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||||
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
|
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
|
||||||
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
|
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
from transformers import EncoderDecoderModel, BertTokenizer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||||
|
|
||||||
|
# forward
|
||||||
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||||
|
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||||
|
|
||||||
|
# training
|
||||||
|
loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, lm_labels=input_ids)[:2]
|
||||||
|
|
||||||
|
# generation
|
||||||
|
generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||||
|
|||||||
Reference in New Issue
Block a user