`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
Sam Shleifer
2020-03-02 10:35:53 -05:00
committed by GitHub
parent 6b1ff25084
commit b54ef78d0c
8 changed files with 544 additions and 152 deletions

View File

@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bart-large": _bart_large_url,
"bart-large-mnli": _bart_large_url, # fine as same
"bart-cnn": None, # not done
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
}
@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
output_past=False,
num_labels=3,
bos_token_id=0,
**common_kwargs
):
r"""
@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
config = BartConfig.from_pretrained('bart-large')
model = BartModel(config)
"""
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
super().__init__(
num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
**common_kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.eos_token_id = eos_token_id
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads

View File

@@ -23,9 +23,11 @@ import fairseq
import torch
from packaging import version
from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0")
@@ -33,7 +35,7 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = "Hello world! cécé herlolip"
SAMPLE_TEXT = " Hello world! cécé herlolip"
rename_keys = [
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
@@ -41,7 +43,7 @@ rename_keys = [
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
def rename_key(dct, old, new):
@@ -53,36 +55,45 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our BERT structure.
"""
b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
b2.eval() # disable dropout
b2.model.upgrade_state_dict(b2.model.state_dict())
config = BartConfig()
tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
bart.eval() # disable dropout
bart.model.upgrade_state_dict(bart.model.state_dict())
hf_model_name = checkpoint_path.replace(".", "-")
config = BartConfig.from_pretrained(hf_model_name)
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
assert torch.eq(tokens, tokens2).all()
# assert their_output.size() == (1, 11, 1024)
if checkpoint_path == "bart.large":
state_dict = b2.model.state_dict()
if checkpoint_path in ["bart.large", "bart.large.cnn"]:
state_dict = bart.model.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = BartModel(config)
their_output = b2.extract_features(tokens)
their_output = bart.extract_features(tokens)
else: # MNLI Case
state_dict = b2.state_dict()
state_dict = bart.state_dict()
for k in IGNORE_KEYS:
state_dict.pop(k, None)
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
state_dict.pop("_float_tensor", None)
model = BartForSequenceClassification(config)
their_output = b2.predict("mnli", tokens, return_logits=True)
for k in IGNORE_KEYS:
state_dict.pop(k, None)
their_output = bart.predict("mnli", tokens, return_logits=True)
# Load state dict
model.load_state_dict(state_dict)
model.eval()
our_outputs = model.forward(tokens)[0]
# Check results
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
model = BartForMaskedLM(config, base_model=model)
assert "lm_head.weight" in model.state_dict()
assert model.lm_head.out_features == config.max_position_embeddings
model.eval()
our_outputs = model.model.forward(tokens)[0]
else:
our_outputs = model.forward(tokens)[0]
assert their_output.shape == our_outputs.shape
assert (their_output == our_outputs).all().item()
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
@@ -92,7 +103,8 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="")
parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_bart_checkpoint(

View File

@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
@@ -24,7 +24,7 @@ from torch import Tensor, nn
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
@@ -332,7 +333,7 @@ class DecoderLayer(nn.Module):
x,
encoder_hidden_states,
encoder_attn_mask=None,
decoder_cached_states=None,
layer_state=None,
attention_mask=None,
need_attn_weights=False,
):
@@ -348,43 +349,28 @@ class DecoderLayer(nn.Module):
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if decoder_cached_states is None:
prev_self_attn_state, prev_attn_state = (None, None)
else:
assert len(decoder_cached_states) == 3
prev_self_attn_state, prev_attn_state = (
decoder_cached_states["self"],
decoder_cached_states["encoder_decoder"],
)
residual = x
if prev_self_attn_state is not None:
saved_state = prev_self_attn_state
decoder_cached_states["self"] = saved_state
y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
if layer_state is None:
layer_state = {}
# next line mutates layer state
x, self_attn_weights = self.self_attn.forward(
query=x,
key=y,
value=y,
decoder_cached_states=decoder_cached_states,
need_weights=need_attn_weights,
attn_mask=attention_mask,
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if prev_attn_state is not None:
saved_state = prev_attn_state
decoder_cached_states["encoder_decoder"] = saved_state
x, encoder_attn_weights = self.encoder_attn.forward(
query=x,
key=encoder_hidden_states, # could be None
value=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
decoder_cached_states=decoder_cached_states,
layer_state=layer_state, # mutates layer state
static_kv=True,
need_weights=False, # not returning it so why compute it
)
@@ -403,15 +389,8 @@ class DecoderLayer(nn.Module):
return (
x,
self_attn_weights,
decoder_cached_states,
) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding
def _past_to_dict(self, prev_attn_state):
prev_key, prev_value = prev_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
return saved_state
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class BartDecoder(nn.Module):
@@ -440,6 +419,7 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.generation_mode = False
def forward(
self,
@@ -469,11 +449,15 @@ class BartDecoder(nn.Module):
- attentions
"""
# embed positions
positions = self.embed_positions(input_ids)
x = self.embed_tokens(input_ids)
positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
if positions is not None:
x += positions
if self.generation_mode:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids)
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -489,17 +473,19 @@ class BartDecoder(nn.Module):
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer.forward(
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_cached_states=layer_state,
layer_state=layer_state,
attention_mask=combined_mask,
need_attn_weights=self.output_attentions,
)
if self.output_past:
next_decoder_cache.append(layer_past)
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.output_attentions:
@@ -509,7 +495,22 @@ class BartDecoder(nn.Module):
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
return x, next_decoder_cache, all_hidden_states, list(all_self_attns)
if self.output_past:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
return x, next_cache, all_hidden_states, list(all_self_attns)
def reorder_attn_buffer(input_buffer, new_order):
"""Reorder buffered internal state (for incremental generation)."""
# input_buffer = self._get_input_buffer(incremental_state)
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
input_buffer[k] = input_buffer_k.index_select(0, new_order)
# incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return input_buffer
class SelfAttention(nn.Module):
@@ -557,7 +558,7 @@ class SelfAttention(nn.Module):
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = False,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
@@ -579,8 +580,8 @@ class SelfAttention(nn.Module):
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if decoder_cached_states is not None: # get the last k,v and mask for reuse
saved_state = decoder_cached_states.get(self.cache_key, {})
if layer_state is not None: # get the last k,v and mask for reuse
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
@@ -588,6 +589,7 @@ class SelfAttention(nn.Module):
key = value = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if self.encoder_decoder_attention:
@@ -608,17 +610,16 @@ class SelfAttention(nn.Module):
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask, new_state = self._use_and_update_saved_state(
k, v, saved_state, key_padding_mask, static_kv, bsz
)
saved_state.update(
{
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask,
}
)
decoder_cached_states[self.cache_key] = saved_state # Update cache
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
# assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
@@ -632,7 +633,7 @@ class SelfAttention(nn.Module):
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len)
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -650,7 +651,7 @@ class SelfAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
return attn_output, attn_weights
def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
@@ -675,7 +676,7 @@ class SelfAttention(nn.Module):
key_padding_mask = self._cat_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
)
return k, v, key_padding_mask, saved_state
return k, v, key_padding_mask
@staticmethod
def _cat_prev_key_padding_mask(
@@ -693,7 +694,6 @@ class SelfAttention(nn.Module):
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
@@ -747,9 +747,13 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input):
def forward(self, input, generation_mode=False):
"""Input is expected to be of size [bsz x seqlen]."""
positions = create_position_ids_from_input_ids(input, self.padding_idx)
if generation_mode: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
positions = create_position_ids_from_input_ids(input, self.padding_idx)
return super().forward(positions)
@@ -826,21 +830,20 @@ class BartModel(PretrainedBartModel):
assert attention_mask.max() <= 0
# make masks if user doesn't supply
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
)
if not self.decoder.generation_mode:
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
)
assert decoder_input_ids is not None
if encoder_outputs is None:
# TODO(SS): make this caching more usable when overwrite generate
encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(encoder_outputs, tuple)
# dec_features, decoder_cached_states, dec_hidden, dec_attn
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder.forward(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_attn_mask,
decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
)
# Attention and hidden_states will be [] or None if they aren't needed
@@ -856,20 +859,26 @@ class BartModel(PretrainedBartModel):
self.shared = value
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared)
return _make_linear_from_emb(self.shared) # make it on the fly
@add_start_docstrings(
"The bare BART Model with a language modeling head", BART_START_DOCSTRING,
"The bare BART Model with a language modeling head. This is the model used for summarization.",
BART_START_DOCSTRING,
)
class BartForMaskedLM(PretrainedBartModel):
base_model_prefix = "model"
def __init__(self, config: BartConfig):
super().__init__(config)
self.model = BartModel(config)
# if base_model is None:
base_model = BartModel(config)
self.model = base_model
self.lm_head = _make_linear_from_emb(self.model.shared)
def tie_weights(self):
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
self,
@@ -935,12 +944,309 @@ class BartForMaskedLM(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation(input_ids, past, **kwargs):
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": input_ids, # ignored after first pass
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
# "decoder_attention_mask": decoder_attention_mask,
}
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
reordered_past = []
for layer_past in decoder_cached_states:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
def get_output_embeddings(self):
return self.lm_head
@torch.no_grad()
def generate(
self,
input_ids,
attention_mask=None,
max_length=20,
num_beams=1,
repetition_penalty=1.0,
length_penalty=1.0,
num_return_sequences=1,
min_len=0,
no_repeat_ngram_size=0,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
.. _`XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
.. _`Fairseq beam search code`:
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
min_len: (`optional`) int
Returns:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is <= max_length (examples can finish early)
Examples::
config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
"""
bos_token_id = self.config.bos_token_id
pad_token_id = self.config.pad_token_id
eos_token_id = self.config.eos_token_id
batch_size, cur_len = input_ids.shape
assert input_ids is not None
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(pad_token_id, int)
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a positive integer."
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # shape: (batch_size * num_return_sequences, cur_len)
batch_size *= num_return_sequences
# Below here somewhat similar to PretrainedModel._generate_beam_search
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
if attention_mask is not None:
attention_mask = (
attention_mask.unsqueeze(1)
.expand(batch_size, num_beams, cur_len)
.contiguous()
.view(batch_size * num_beams, cur_len)
) # RESHAPE
# generated hypotheses
finalized_hyps = [ # they end in EOS and we wont work on them more!
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 # avoid ties in first step
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# decoder tokens
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
decoder_cache = None
done = [False for _ in range(batch_size)] # done sentences
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
lprobs[lprobs != lprobs] = -math.inf # block nans
lprobs[:, pad_token_id] = -math.inf
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
if step == 0: # Force BOS to be chosen
lprobs[:, bos_token_id + 1 :] = -math.inf
elif step < min_len: # Prevent EOS from being chosen
lprobs[:, eos_token_id] = -math.inf
elif step == max_length: # FORCE EOS to be chosen
lprobs[:, :eos_token_id] = -math.inf
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams
if no_repeat_ngram_size > 0: # copied from fairseq
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
# then set their probabilities tof -inf
for idx in range(num_hypos):
lprobs[idx, banned_tokens[idx]] = -math.inf
assert lprobs.size() == (batch_size * num_beams, vocab_size)
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
# list of (batch_size * num_beams)
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
for batch_idx in range(batch_size):
# if we are done with this sentence (because we can't improve)
if done[batch_idx]: # then pad all associated hypotheses
assert (
len(finalized_hyps[batch_idx]) >= num_beams
), "Example can only be done if at least {} beams have been generated".format(num_beams)
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# Otherwise generate some next word choices
next_sent_beam = []
# add next words for this sentence
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
beam_id = idx // vocab_size
word_id = idx % vocab_size
assert prev_output_tokens.shape[1] == (step + 1)
if word_id.item() == eos_token_id:
if i >= num_beams:
continue
finalized_hyps[batch_idx].add(
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
)
else:
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
break
# Check if were done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=step + 1,
)
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order decoder inputs to [beam_idx]
prev_output_tokens = prev_output_tokens[beam_idx]
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
# re-order internal states
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps
if done[batch_idx]:
continue
offset = batch_idx * num_beams
for i in range(num_beams):
score = beam_scores[offset + i]
final_tokens = prev_output_tokens[offset + i]
finalized_hyps[batch_idx].add(final_tokens, score.item())
# select the best hypotheses
sent_lengths = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(finalized_hyps):
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
sent_lengths[i] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
# TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded[:, 1:] # get rid of starting EOS
@staticmethod
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
"""Copied from fairseq for no_repeat_ngram in beam_search"""
# TODO(SS): this can go on parent if there is demand
if step + 2 < no_repeat_ngram_size:
return [
[] for _ in range(num_hypos)
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
gen_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_output_tokens[idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
k = tuple(ngram[:-1])
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
return gen_ngrams[hypo_idx].get(ngram_index, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,

View File

@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
else:
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure word embedding weights are still tied if needed
# Set model in evaluation mode to desactivate DropOut modules by default
@@ -574,16 +573,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
return False
if mem_len > 0 or has_output_past:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
return True
return False
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
@torch.no_grad()
def generate(
self,
@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # (batch_size * num_return_sequences, cur_len)
) # shape: (batch_size * num_return_sequences, cur_len)
effective_batch_size = batch_size * num_return_sequences
else:
effective_batch_size = batch_size
@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_token] < 0:
next_token_logits[i, previous_token] *= repetition_penalty
else:
next_token_logits[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
# assert input_ids.shape == (batch_size * num_beams, cur_len)
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size * num_beams):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= repetition_penalty
else:
scores[i, previous_token] /= repetition_penalty
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
class Conv1D(nn.Module):

View File

@@ -19,11 +19,7 @@ from .tokenization_roberta import RobertaTokenizer
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = [
"bart-large",
"bart-large-mnli",
# "bart-large-cnn"
]
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
class BartTokenizer(RobertaTokenizer):