From 9c1bdb5b61303bbdfbc3b9759f5c5fa847cb377d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 30 Oct 2019 10:43:13 +0100 Subject: [PATCH] revert renaming of lm_labels to ltr_lm_labels --- examples/run_summarization_finetuning.py | 6 +++--- transformers/modeling_bert.py | 14 +++++++------- transformers/modeling_seq2seq.py | 22 +++++++++++++--------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index 2dc8c660ce..3d194950c7 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""): model.eval() for batch in tqdm(eval_dataloader, desc="Evaluating"): - source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch + source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch source = source.to(args.device) target = target.to(args.device) encoder_token_type_ids = encoder_token_type_ids.to(args.device) encoder_mask = encoder_mask.to(args.device) decoder_mask = decoder_mask.to(args.device) - ltr_lm_labels = ltr_lm_labels.to(args.device) + lm_labels = lm_labels.to(args.device) with torch.no_grad(): outputs = model( @@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""): encoder_token_type_ids=encoder_token_type_ids, encoder_attention_mask=encoder_mask, decoder_attention_mask=decoder_mask, - decoder_ltr_lm_labels=ltr_lm_labels, + decoder_lm_labels=lm_labels, ) lm_loss = outputs[0] eval_loss += lm_loss.mean().item() diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 3fec69a814..11fcdde685 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -791,7 +791,7 @@ class BertForMaskedLM(BertPreTrainedModel): Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-1`` 
are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - **ltr_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels @@ -800,7 +800,7 @@ class BertForMaskedLM(BertPreTrainedModel): Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Masked language modeling loss. - **ltr_lm_loss**: (`optional`, returned when ``ltr_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + **ltr_lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Next token prediction loss. **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). @@ -838,7 +838,7 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert.embeddings.word_embeddings) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, - masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, ltr_lm_labels=None, ): + masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ): outputs = self.bert(input_ids, attention_mask=attention_mask, @@ -857,19 +857,19 @@ class BertForMaskedLM(BertPreTrainedModel): # 1. 
If a tensor that contains the indices of masked labels is provided, # the cross-entropy is the MLM cross-entropy that measures the likelihood # of predictions for masked words. - # 2. If `ltr_lm_labels` is provided we are in a causal scenario where we + # 2. If `lm_labels` is provided we are in a causal scenario where we # try to predict the next token for each input in the decoder. if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) outputs = (masked_lm_loss,) + outputs - if ltr_lm_labels is not None: + if lm_labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one prediction_scores = prediction_scores[:, :-1, :].contiguous() - ltr_lm_labels = ltr_lm_labels[:, 1:].contiguous() + lm_labels = lm_labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(ignore_index=-1) - ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), ltr_lm_labels.view(-1)) + ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) outputs = (ltr_lm_loss,) + outputs return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 22898db9a1..ba8c546a30 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -30,10 +30,10 @@ logger = logging.getLogger(__name__) class PreTrainedSeq2seq(nn.Module): r""" :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be - instantiated as a Seq2seq model with one of the base model classes of - the library as encoder and (optionally) as decoder when created with - the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class - method. 
+ instantiated as a transformer architecture with one of the base model + classes of the library as encoder and (optionally) another one as + decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` + class method. """ def __init__(self, encoder, decoder): @@ -59,13 +59,13 @@ class PreTrainedSeq2seq(nn.Module): encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. - - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either: - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. - - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``. - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. 
This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. model_args: (`optional`) Sequence of positional arguments: @@ -103,7 +103,7 @@ class PreTrainedSeq2seq(nn.Module): - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. - You can specify different kwargs for the decoder by prefixing the key with `decoder_` (e.g. ``decoder_output_attention=True``). + You can specify kwargs specific to the encoder and decoder by prefixing the key with `encoder_` and `decoder_` respectively. (e.g. ``decoder_output_attention=True``). The remaining kwargs will be passed to both encoders and decoders. Examples:: @@ -154,8 +154,11 @@ class PreTrainedSeq2seq(nn.Module): return model def save_pretrained(self, save_directory): - """ Save a Seq2Seq model and its configuration file in a format - such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """ + """ Save a Seq2Seq model and its configuration file in a format such + that it can be loaded using :func:`~transformers.PreTrainedSeq2seq.from_pretrained` + + We save the encoder's and decoder's parameters in two separate directories. 
+ """ self.encoder.save_pretrained(os.path.join(save_directory, "encoder")) self.decoder.save_pretrained(os.path.join(save_directory, "decoder")) @@ -176,6 +179,7 @@ class PreTrainedSeq2seq(nn.Module): Indices of encoder input sequence tokens in the vocabulary. decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` Indices of decoder input sequence tokens in the vocabulary. + kwargs: (`optional`) Remaining dictionary of keyword arguments. """ # keyword arguments come in 3 flavors: encoder-specific (prefixed by # `encoder_`), decoder-specific (prefixed by `decoder_`) and those