diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py index 0b9873f910..2fafbebb8d 100644 --- a/src/transformers/configuration_encoder_decoder.py +++ b/src/transformers/configuration_encoder_decoder.py @@ -27,18 +27,35 @@ class EncoderDecoderConfig(PretrainedConfig): r""" :class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a `EncoderDecoderModel`. - It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs. - Configuration objects inherit from :class:`~transformers.PretrainedConfig` - and can be used to control the model outputs. - See the documentation for :class:`~transformers.PretrainedConfig` for more information. + It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs. + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. + See the documentation for :class:`~transformers.PretrainedConfig` for more information. + Args: + kwargs (`optional`): + Remaining dictionary of keyword arguments. Notably: + encoder (:class:`PretrainedConfig`, optional, defaults to `None`): + An instance of a configuration object that defines the encoder config. + encoder (:class:`PretrainedConfig`, optional, defaults to `None`): + An instance of a configuration object that defines the decoder config. - Arguments: - kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably: - encoder (:class:`PretrainedConfig`, optional, defaults to `None`): - An instance of a configuration object that defines the encoder config. - encoder (:class:`PretrainedConfig`, optional, defaults to `None`): - An instance of a configuration object that defines the decoder config. + Example:: + + from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel + + # Initializing a BERT bert-base-uncased style configuration + config_encoder = BertConfig() + config_decoder = BertConfig() + + config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) + + # Initializing a Bert2Bert model from the bert-base-uncased style configurations + model = EncoderDecoderModel(config=config) + + # Accessing the model configuration + config_encoder = model.config.encoder + config_decoder = model.config.decoder """ model_type = "encoder_decoder" diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 02b95c7baa..451edc6c03 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -125,6 +125,8 @@ class EncoderDecoderModel(PreTrainedModel): Examples:: + from tranformers import EncoderDecoder + model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert """ @@ -230,6 +232,25 @@ class EncoderDecoderModel(PreTrainedModel): kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function. + + Examples:: + + from transformers import EncoderDecoderModel, BertTokenizer + import torch + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert + + # forward + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) + + # training + loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, lm_labels=input_ids)[:2] + + # generation + generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id) + """ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}