Bart-CNN (#3059)
`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
@@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model, ported from the fairseq repo."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -24,7 +24,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .configuration_bart import BartConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
|
||||
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
||||
}
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
@@ -332,7 +333,7 @@ class DecoderLayer(nn.Module):
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_attn_mask=None,
|
||||
decoder_cached_states=None,
|
||||
layer_state=None,
|
||||
attention_mask=None,
|
||||
need_attn_weights=False,
|
||||
):
|
||||
@@ -348,43 +349,28 @@ class DecoderLayer(nn.Module):
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
if decoder_cached_states is None:
|
||||
prev_self_attn_state, prev_attn_state = (None, None)
|
||||
else:
|
||||
assert len(decoder_cached_states) == 3
|
||||
prev_self_attn_state, prev_attn_state = (
|
||||
decoder_cached_states["self"],
|
||||
decoder_cached_states["encoder_decoder"],
|
||||
)
|
||||
|
||||
residual = x
|
||||
if prev_self_attn_state is not None:
|
||||
saved_state = prev_self_attn_state
|
||||
decoder_cached_states["self"] = saved_state
|
||||
y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
|
||||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
x, self_attn_weights = self.self_attn.forward(
|
||||
query=x,
|
||||
key=y,
|
||||
value=y,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
need_weights=need_attn_weights,
|
||||
attn_mask=attention_mask,
|
||||
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
if prev_attn_state is not None:
|
||||
saved_state = prev_attn_state
|
||||
decoder_cached_states["encoder_decoder"] = saved_state
|
||||
|
||||
x, encoder_attn_weights = self.encoder_attn.forward(
|
||||
query=x,
|
||||
key=encoder_hidden_states, # could be None
|
||||
value=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
static_kv=True,
|
||||
need_weights=False, # not returning it so why compute it
|
||||
)
|
||||
@@ -403,15 +389,8 @@ class DecoderLayer(nn.Module):
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
decoder_cached_states,
|
||||
) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding
|
||||
|
||||
def _past_to_dict(self, prev_attn_state):
|
||||
prev_key, prev_value = prev_attn_state[:2]
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
if len(prev_attn_state) >= 3:
|
||||
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
||||
return saved_state
|
||||
layer_state,
|
||||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
|
||||
|
||||
class BartDecoder(nn.Module):
|
||||
@@ -440,6 +419,7 @@ class BartDecoder(nn.Module):
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.generation_mode = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -469,11 +449,15 @@ class BartDecoder(nn.Module):
|
||||
- attentions
|
||||
"""
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_ids)
|
||||
x = self.embed_tokens(input_ids)
|
||||
positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
|
||||
|
||||
if positions is not None:
|
||||
x += positions
|
||||
if self.generation_mode:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
x += positions
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
@@ -489,17 +473,19 @@ class BartDecoder(nn.Module):
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
x, layer_self_attn, layer_past = decoder_layer.forward(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_padding_mask,
|
||||
decoder_cached_states=layer_state,
|
||||
layer_state=layer_state,
|
||||
attention_mask=combined_mask,
|
||||
need_attn_weights=self.output_attentions,
|
||||
)
|
||||
|
||||
if self.output_past:
|
||||
next_decoder_cache.append(layer_past)
|
||||
next_decoder_cache.append(layer_past.copy())
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
if self.output_attentions:
|
||||
@@ -509,7 +495,22 @@ class BartDecoder(nn.Module):
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, next_decoder_cache, all_hidden_states, list(all_self_attns)
|
||||
if self.output_past:
|
||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
||||
else:
|
||||
next_cache = None
|
||||
return x, next_cache, all_hidden_states, list(all_self_attns)
|
||||
|
||||
|
||||
def reorder_attn_buffer(input_buffer, new_order):
|
||||
"""Reorder buffered internal state (for incremental generation)."""
|
||||
# input_buffer = self._get_input_buffer(incremental_state)
|
||||
for k in input_buffer.keys():
|
||||
input_buffer_k = input_buffer[k]
|
||||
if input_buffer_k is not None:
|
||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||
# incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
||||
return input_buffer
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -557,7 +558,7 @@ class SelfAttention(nn.Module):
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = False,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
@@ -579,8 +580,8 @@ class SelfAttention(nn.Module):
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
# get here for encoder decoder cause of static_kv
|
||||
if decoder_cached_states is not None: # get the last k,v and mask for reuse
|
||||
saved_state = decoder_cached_states.get(self.cache_key, {})
|
||||
if layer_state is not None: # get the last k,v and mask for reuse
|
||||
saved_state = layer_state.get(self.cache_key, {})
|
||||
if "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute key and value if they are static
|
||||
if static_kv:
|
||||
@@ -588,6 +589,7 @@ class SelfAttention(nn.Module):
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
layer_state = {}
|
||||
|
||||
q = self.q_proj(query) * self.scaling
|
||||
if self.encoder_decoder_attention:
|
||||
@@ -608,17 +610,16 @@ class SelfAttention(nn.Module):
|
||||
v = self._shape(v, -1, bsz)
|
||||
|
||||
if saved_state is not None:
|
||||
k, v, key_padding_mask, new_state = self._use_and_update_saved_state(
|
||||
k, v, saved_state, key_padding_mask, static_kv, bsz
|
||||
)
|
||||
saved_state.update(
|
||||
{
|
||||
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_key_padding_mask": key_padding_mask,
|
||||
}
|
||||
)
|
||||
decoder_cached_states[self.cache_key] = saved_state # Update cache
|
||||
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
||||
# assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
|
||||
|
||||
# Update cache
|
||||
layer_state[self.cache_key] = {
|
||||
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
||||
}
|
||||
|
||||
assert k is not None
|
||||
src_len = k.size(1)
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
@@ -632,7 +633,7 @@ class SelfAttention(nn.Module):
|
||||
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len)
|
||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
|
||||
|
||||
if key_padding_mask is not None: # don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
@@ -650,7 +651,7 @@ class SelfAttention(nn.Module):
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
@@ -675,7 +676,7 @@ class SelfAttention(nn.Module):
|
||||
key_padding_mask = self._cat_prev_key_padding_mask(
|
||||
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
|
||||
)
|
||||
return k, v, key_padding_mask, saved_state
|
||||
return k, v, key_padding_mask
|
||||
|
||||
@staticmethod
|
||||
def _cat_prev_key_padding_mask(
|
||||
@@ -693,7 +694,6 @@ class SelfAttention(nn.Module):
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
|
||||
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
|
||||
if prev_key_padding_mask.is_cuda:
|
||||
filler = filler.cuda()
|
||||
@@ -747,9 +747,13 @@ class LearnedPositionalEmbedding(nn.Embedding):
|
||||
num_embeddings += padding_idx + 1 # WHY?
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input, generation_mode=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
positions = create_position_ids_from_input_ids(input, self.padding_idx)
|
||||
if generation_mode: # the position is our current step in the decoded sequence
|
||||
pos = int(self.padding_idx + input.size(1))
|
||||
positions = input.data.new(1, 1).fill_(pos)
|
||||
else:
|
||||
positions = create_position_ids_from_input_ids(input, self.padding_idx)
|
||||
return super().forward(positions)
|
||||
|
||||
|
||||
@@ -826,21 +830,20 @@ class BartModel(PretrainedBartModel):
|
||||
assert attention_mask.max() <= 0
|
||||
|
||||
# make masks if user doesn't supply
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
|
||||
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
if not self.decoder.generation_mode:
|
||||
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
|
||||
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
|
||||
)
|
||||
assert decoder_input_ids is not None
|
||||
if encoder_outputs is None:
|
||||
# TODO(SS): make this caching more usable when overwrite generate
|
||||
encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
|
||||
assert isinstance(encoder_outputs, tuple)
|
||||
# dec_features, decoder_cached_states, dec_hidden, dec_attn
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder.forward(
|
||||
decoder_input_ids,
|
||||
encoder_outputs[0],
|
||||
attention_mask,
|
||||
decoder_attn_mask,
|
||||
decoder_attention_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
)
|
||||
# Attention and hidden_states will be [] or None if they aren't needed
|
||||
@@ -856,20 +859,26 @@ class BartModel(PretrainedBartModel):
|
||||
self.shared = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return _make_linear_from_emb(self.shared)
|
||||
return _make_linear_from_emb(self.shared) # make it on the fly
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BART Model with a language modeling head", BART_START_DOCSTRING,
|
||||
"The bare BART Model with a language modeling head. This is the model used for summarization.",
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class BartForMaskedLM(PretrainedBartModel):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
self.model = BartModel(config)
|
||||
# if base_model is None:
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.lm_head = _make_linear_from_emb(self.model.shared)
|
||||
|
||||
def tie_weights(self):
|
||||
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@@ -935,12 +944,309 @@ class BartForMaskedLM(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(input_ids, past, **kwargs):
|
||||
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
|
||||
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
return {
|
||||
"input_ids": input_ids, # ignored after first pass
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
# "decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_cached_states) = past
|
||||
reordered_past = []
|
||||
for layer_past in decoder_cached_states:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {
|
||||
attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
}
|
||||
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
reordered_past.append(layer_past_new)
|
||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
|
||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
||||
|
||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||
return past
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
max_length=20,
|
||||
num_beams=1,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
num_return_sequences=1,
|
||||
min_len=0,
|
||||
no_repeat_ngram_size=0,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
|
||||
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
|
||||
|
||||
.. _`XLM beam search code`:
|
||||
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
||||
.. _`Fairseq beam search code`:
|
||||
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
|
||||
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
||||
The sequence used as a prompt for the generation. If `None` the method initializes
|
||||
it as an empty `torch.LongTensor` of shape `(1,)`.
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Does not include tokens in input_ids.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
num_return_sequences: (`optional`) int
|
||||
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
||||
|
||||
min_len: (`optional`) int
|
||||
|
||||
Returns:
|
||||
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
|
||||
sequence_length is <= max_length (examples can finish early)
|
||||
|
||||
Examples::
|
||||
|
||||
config = BartConfig(vocab_size=50264, output_past=True)
|
||||
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
|
||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
|
||||
# Generate Summary
|
||||
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
|
||||
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
|
||||
|
||||
"""
|
||||
bos_token_id = self.config.bos_token_id
|
||||
pad_token_id = self.config.pad_token_id
|
||||
eos_token_id = self.config.eos_token_id
|
||||
batch_size, cur_len = input_ids.shape
|
||||
assert input_ids is not None
|
||||
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||
assert isinstance(pad_token_id, int)
|
||||
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
|
||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
), "`num_return_sequences` should be a positive integer."
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = input_ids.shape[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
if num_return_sequences != 1:
|
||||
# Expand input to num return sequences
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||
input_ids = input_ids.contiguous().view(
|
||||
batch_size * num_return_sequences, cur_len
|
||||
) # shape: (batch_size * num_return_sequences, cur_len)
|
||||
batch_size *= num_return_sequences
|
||||
|
||||
# Below here somewhat similar to PretrainedModel._generate_beam_search
|
||||
# Expand input to num beams
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
|
||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||
if attention_mask is not None:
|
||||
attention_mask = (
|
||||
attention_mask.unsqueeze(1)
|
||||
.expand(batch_size, num_beams, cur_len)
|
||||
.contiguous()
|
||||
.view(batch_size * num_beams, cur_len)
|
||||
) # RESHAPE
|
||||
|
||||
# generated hypotheses
|
||||
finalized_hyps = [ # they end in EOS and we wont work on them more!
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9 # avoid ties in first step
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# decoder tokens
|
||||
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
|
||||
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
|
||||
decoder_cache = None
|
||||
done = [False for _ in range(batch_size)] # done sentences
|
||||
|
||||
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
||||
for step in range(max_length + 1):
|
||||
decoder_input_ids = prev_output_tokens.clone()
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
|
||||
|
||||
lprobs[lprobs != lprobs] = -math.inf # block nans
|
||||
lprobs[:, pad_token_id] = -math.inf
|
||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||
|
||||
if step == 0: # Force BOS to be chosen
|
||||
lprobs[:, bos_token_id + 1 :] = -math.inf
|
||||
elif step < min_len: # Prevent EOS from being chosen
|
||||
lprobs[:, eos_token_id] = -math.inf
|
||||
elif step == max_length: # FORCE EOS to be chosen
|
||||
lprobs[:, :eos_token_id] = -math.inf
|
||||
lprobs[:, eos_token_id + 1 :] = -math.inf
|
||||
assert self._do_output_past(outputs)
|
||||
decoder_cache = outputs[1]
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
|
||||
num_hypos = batch_size * num_beams
|
||||
if no_repeat_ngram_size > 0: # copied from fairseq
|
||||
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
|
||||
# then set their probabilities tof -inf
|
||||
for idx in range(num_hypos):
|
||||
lprobs[idx, banned_tokens[idx]] = -math.inf
|
||||
assert lprobs.size() == (batch_size * num_beams, vocab_size)
|
||||
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
||||
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
|
||||
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
|
||||
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
|
||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
|
||||
|
||||
# list of (batch_size * num_beams)
|
||||
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
|
||||
for batch_idx in range(batch_size):
|
||||
# if we are done with this sentence (because we can't improve)
|
||||
if done[batch_idx]: # then pad all associated hypotheses
|
||||
assert (
|
||||
len(finalized_hyps[batch_idx]) >= num_beams
|
||||
), "Example can only be done if at least {} beams have been generated".format(num_beams)
|
||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
|
||||
# Otherwise generate some next word choices
|
||||
next_sent_beam = []
|
||||
# add next words for this sentence
|
||||
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
|
||||
beam_id = idx // vocab_size
|
||||
word_id = idx % vocab_size
|
||||
assert prev_output_tokens.shape[1] == (step + 1)
|
||||
if word_id.item() == eos_token_id:
|
||||
if i >= num_beams:
|
||||
continue
|
||||
finalized_hyps[batch_idx].add(
|
||||
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
|
||||
)
|
||||
else:
|
||||
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
|
||||
|
||||
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
|
||||
break
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len=step + 1,
|
||||
)
|
||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
||||
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * num_beams
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||
# re-order decoder inputs to [beam_idx]
|
||||
prev_output_tokens = prev_output_tokens[beam_idx]
|
||||
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
# Add all open beam hypothesis to generated_hyps
|
||||
if done[batch_idx]:
|
||||
continue
|
||||
offset = batch_idx * num_beams
|
||||
for i in range(num_beams):
|
||||
score = beam_scores[offset + i]
|
||||
final_tokens = prev_output_tokens[offset + i]
|
||||
finalized_hyps[batch_idx].add(final_tokens, score.item())
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size)
|
||||
best = []
|
||||
for i, hypotheses in enumerate(finalized_hyps):
|
||||
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
|
||||
sent_lengths[i] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are filled with pad_token
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
# TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
|
||||
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
return decoded[:, 1:] # get rid of starting EOS
|
||||
|
||||
@staticmethod
|
||||
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
# TODO(SS): this can go on parent if there is demand
|
||||
if step + 2 < no_repeat_ngram_size:
|
||||
return [
|
||||
[] for _ in range(num_hypos)
|
||||
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
gen_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_output_tokens[idx].tolist()
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
||||
k = tuple(ngram[:-1])
|
||||
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
|
||||
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
|
||||
return gen_ngrams[hypo_idx].get(ngram_index, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
return banned_tokens
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
Reference in New Issue
Block a user