BartForCausalLM analogs to ProphetNetForCausalLM (#9128)
* initiliaze bart4causalLM * create BartDecoderWrapper, setters/getters * delete spaces * forward and additional methods * update cache function, loss function, remove ngram* params in data class. * add bartcausallm, bartdecoder testing * correct bart for causal lm * remove at * add mbart as well * up * fix typo * up * correct * add pegasusforcausallm * add blenderbotforcausallm * add blenderbotsmallforcausallm * add marianforcausallm * add test for MarianForCausalLM * add Pegasus test * add BlenderbotSmall test * add blenderbot test * fix a fail * fix an import fail * a fix * fix * Update modeling_pegasus.py * fix models * fix inputs_embeds setting getter * adapt tests * correct repo utils check * finish test improvement * fix tf models as well * make style * make fix-copies * fix copies * run all tests * last changes * fix all tests Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
"BlenderbotSmallForConditionalGeneration",
|
||||
"BlenderbotSmallModel",
|
||||
"BlenderbotSmallPreTrainedModel",
|
||||
"BlenderbotSmallForCausalLM",
|
||||
]
|
||||
|
||||
if is_tf_available():
|
||||
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .modeling_blenderbot_small import (
|
||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BlenderbotSmallForCausalLM,
|
||||
BlenderbotSmallForConditionalGeneration,
|
||||
BlenderbotSmallModel,
|
||||
BlenderbotSmallPreTrainedModel,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
""" PyTorch BlenderbotSmall model. """
|
||||
|
||||
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
@@ -35,6 +36,7 @@ from ...file_utils import (
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
@@ -805,6 +807,31 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -907,19 +934,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
|
||||
if attention_mask is not None and combined_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
@@ -938,7 +955,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
if head_mask is not None:
|
||||
@@ -974,7 +991,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
combined_attention_mask,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
@@ -985,7 +1002,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=combined_attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
@@ -1001,7 +1018,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
@@ -1310,3 +1329,210 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall
|
||||
class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
|
||||
"""
|
||||
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.decoder = BlenderbotSmallDecoder(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.decoder(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall
|
||||
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
self.model = BlenderbotSmallDecoderWrapper(config)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.decoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.decoder.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.decoder = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||
config.vocab_size]``.
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||
for more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForCausalLM
|
||||
|
||||
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained('facebook/bart-large')
|
||||
>>> model = BlenderbotSmallForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
||||
@@ -925,7 +925,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||
if inputs["attention_mask"] is not None:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user