Bart-CNN (#3059)
`generate` code that produces 99% identical summarizations to fairseq on CNN test data, with caching.
This commit is contained in:
@@ -4,20 +4,27 @@ Bart
|
||||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||
@sshleifer
|
||||
|
||||
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer on 29 Oct, 2019.
|
||||
It is a sequence to sequence model where both encoder and decoder are transformers. The paper also introduces a novel pretraining objective, and demonstrates excellent summarization results.
|
||||
The authors released their code `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
|
||||
Paper
|
||||
~~~~~
|
||||
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.
|
||||
According to the abstract:
|
||||
|
||||
**Abstract:**
|
||||
- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a left-to-right decoder (like GPT).
|
||||
- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, where spans of text are replaced with a single mask token.
|
||||
- BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE.
|
||||
|
||||
*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Tranformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.*
|
||||
`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension`
|
||||
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
|
||||
|
||||
|
||||
Notes:
|
||||
Implementation Notes
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting.
|
||||
- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs.
|
||||
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space.
|
||||
- Decoder inputs are created automatically by the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``
|
||||
BartModel
|
||||
- ``MaskedLM.generate`` should be used for summarization, see the example in that docstrings
|
||||
|
||||
|
||||
BartModel
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -30,7 +37,7 @@ BartForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartForMaskedLM
|
||||
:members: forward
|
||||
:members: forward, generate
|
||||
|
||||
|
||||
BartForSequenceClassification
|
||||
|
||||
@@ -280,7 +280,10 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
|
||||
| | | | bart-large base architecture with a classification head |
|
||||
| | | | bart-large base architecture with a classification head, finetuned on MNLI |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
|
||||
| | | | bart-large base architecture finetuned on cnn summarization task |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ _bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large": _bart_large_url,
|
||||
"bart-large-mnli": _bart_large_url, # fine as same
|
||||
"bart-cnn": None, # not done
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
|
||||
classifier_dropout=0.0,
|
||||
output_past=False,
|
||||
num_labels=3,
|
||||
bos_token_id=0,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -67,12 +68,16 @@ class BartConfig(PretrainedConfig):
|
||||
config = BartConfig.from_pretrained('bart-large')
|
||||
model = BartModel(config)
|
||||
"""
|
||||
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
|
||||
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
output_past=output_past,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
**common_kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = self.num_hidden_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
|
||||
@@ -23,9 +23,11 @@ import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer
|
||||
from transformers import BartConfig, BartForMaskedLM, BartForSequenceClassification, BartModel, BartTokenizer
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
|
||||
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
@@ -33,7 +35,7 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
@@ -41,7 +43,7 @@ rename_keys = [
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]
|
||||
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
@@ -53,36 +55,45 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
|
||||
b2.eval() # disable dropout
|
||||
b2.model.upgrade_state_dict(b2.model.state_dict())
|
||||
config = BartConfig()
|
||||
tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
|
||||
bart.eval() # disable dropout
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
hf_model_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_model_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
assert torch.eq(tokens, tokens2).all()
|
||||
|
||||
# assert their_output.size() == (1, 11, 1024)
|
||||
|
||||
if checkpoint_path == "bart.large":
|
||||
state_dict = b2.model.state_dict()
|
||||
if checkpoint_path in ["bart.large", "bart.large.cnn"]:
|
||||
state_dict = bart.model.state_dict()
|
||||
for k in IGNORE_KEYS:
|
||||
state_dict.pop(k, None)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
model = BartModel(config)
|
||||
their_output = b2.extract_features(tokens)
|
||||
|
||||
their_output = bart.extract_features(tokens)
|
||||
else: # MNLI Case
|
||||
state_dict = b2.state_dict()
|
||||
state_dict = bart.state_dict()
|
||||
for k in IGNORE_KEYS:
|
||||
state_dict.pop(k, None)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
state_dict.pop("_float_tensor", None)
|
||||
model = BartForSequenceClassification(config)
|
||||
their_output = b2.predict("mnli", tokens, return_logits=True)
|
||||
for k in IGNORE_KEYS:
|
||||
state_dict.pop(k, None)
|
||||
their_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
|
||||
# Load state dict
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
our_outputs = model.forward(tokens)[0]
|
||||
# Check results
|
||||
|
||||
if checkpoint_path == "bart.large.cnn": # generate doesnt work yet
|
||||
model = BartForMaskedLM(config, base_model=model)
|
||||
assert "lm_head.weight" in model.state_dict()
|
||||
assert model.lm_head.out_features == config.max_position_embeddings
|
||||
model.eval()
|
||||
our_outputs = model.model.forward(tokens)[0]
|
||||
else:
|
||||
our_outputs = model.forward(tokens)[0]
|
||||
assert their_output.shape == our_outputs.shape
|
||||
assert (their_output == our_outputs).all().item()
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
@@ -92,7 +103,8 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="")
|
||||
parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
|
||||
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model, ported from the fairseq repo."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -24,7 +24,7 @@ from torch import Tensor, nn
|
||||
|
||||
from .configuration_bart import BartConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
|
||||
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
||||
}
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
@@ -332,7 +333,7 @@ class DecoderLayer(nn.Module):
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_attn_mask=None,
|
||||
decoder_cached_states=None,
|
||||
layer_state=None,
|
||||
attention_mask=None,
|
||||
need_attn_weights=False,
|
||||
):
|
||||
@@ -348,43 +349,28 @@ class DecoderLayer(nn.Module):
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
if decoder_cached_states is None:
|
||||
prev_self_attn_state, prev_attn_state = (None, None)
|
||||
else:
|
||||
assert len(decoder_cached_states) == 3
|
||||
prev_self_attn_state, prev_attn_state = (
|
||||
decoder_cached_states["self"],
|
||||
decoder_cached_states["encoder_decoder"],
|
||||
)
|
||||
|
||||
residual = x
|
||||
if prev_self_attn_state is not None:
|
||||
saved_state = prev_self_attn_state
|
||||
decoder_cached_states["self"] = saved_state
|
||||
y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
|
||||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
x, self_attn_weights = self.self_attn.forward(
|
||||
query=x,
|
||||
key=y,
|
||||
value=y,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
need_weights=need_attn_weights,
|
||||
attn_mask=attention_mask,
|
||||
query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
if prev_attn_state is not None:
|
||||
saved_state = prev_attn_state
|
||||
decoder_cached_states["encoder_decoder"] = saved_state
|
||||
|
||||
x, encoder_attn_weights = self.encoder_attn.forward(
|
||||
query=x,
|
||||
key=encoder_hidden_states, # could be None
|
||||
value=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
static_kv=True,
|
||||
need_weights=False, # not returning it so why compute it
|
||||
)
|
||||
@@ -403,15 +389,8 @@ class DecoderLayer(nn.Module):
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
decoder_cached_states,
|
||||
) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding
|
||||
|
||||
def _past_to_dict(self, prev_attn_state):
|
||||
prev_key, prev_value = prev_attn_state[:2]
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
if len(prev_attn_state) >= 3:
|
||||
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
||||
return saved_state
|
||||
layer_state,
|
||||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
|
||||
|
||||
class BartDecoder(nn.Module):
|
||||
@@ -440,6 +419,7 @@ class BartDecoder(nn.Module):
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.generation_mode = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -469,10 +449,14 @@ class BartDecoder(nn.Module):
|
||||
- attentions
|
||||
"""
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_ids)
|
||||
x = self.embed_tokens(input_ids)
|
||||
positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode)
|
||||
|
||||
if positions is not None:
|
||||
if self.generation_mode:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
x += positions
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
@@ -489,17 +473,19 @@ class BartDecoder(nn.Module):
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
x, layer_self_attn, layer_past = decoder_layer.forward(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
encoder_padding_mask,
|
||||
decoder_cached_states=layer_state,
|
||||
layer_state=layer_state,
|
||||
attention_mask=combined_mask,
|
||||
need_attn_weights=self.output_attentions,
|
||||
)
|
||||
|
||||
if self.output_past:
|
||||
next_decoder_cache.append(layer_past)
|
||||
next_decoder_cache.append(layer_past.copy())
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
if self.output_attentions:
|
||||
@@ -509,7 +495,22 @@ class BartDecoder(nn.Module):
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, next_decoder_cache, all_hidden_states, list(all_self_attns)
|
||||
if self.output_past:
|
||||
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
|
||||
else:
|
||||
next_cache = None
|
||||
return x, next_cache, all_hidden_states, list(all_self_attns)
|
||||
|
||||
|
||||
def reorder_attn_buffer(input_buffer, new_order):
|
||||
"""Reorder buffered internal state (for incremental generation)."""
|
||||
# input_buffer = self._get_input_buffer(incremental_state)
|
||||
for k in input_buffer.keys():
|
||||
input_buffer_k = input_buffer[k]
|
||||
if input_buffer_k is not None:
|
||||
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
||||
# incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
||||
return input_buffer
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -557,7 +558,7 @@ class SelfAttention(nn.Module):
|
||||
key: Optional[Tensor],
|
||||
value: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
need_weights: bool = False,
|
||||
static_kv: bool = False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
@@ -579,8 +580,8 @@ class SelfAttention(nn.Module):
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
# get here for encoder decoder cause of static_kv
|
||||
if decoder_cached_states is not None: # get the last k,v and mask for reuse
|
||||
saved_state = decoder_cached_states.get(self.cache_key, {})
|
||||
if layer_state is not None: # get the last k,v and mask for reuse
|
||||
saved_state = layer_state.get(self.cache_key, {})
|
||||
if "prev_key" in saved_state:
|
||||
# previous time steps are cached - no need to recompute key and value if they are static
|
||||
if static_kv:
|
||||
@@ -588,6 +589,7 @@ class SelfAttention(nn.Module):
|
||||
key = value = None
|
||||
else:
|
||||
saved_state = None
|
||||
layer_state = {}
|
||||
|
||||
q = self.q_proj(query) * self.scaling
|
||||
if self.encoder_decoder_attention:
|
||||
@@ -608,17 +610,16 @@ class SelfAttention(nn.Module):
|
||||
v = self._shape(v, -1, bsz)
|
||||
|
||||
if saved_state is not None:
|
||||
k, v, key_padding_mask, new_state = self._use_and_update_saved_state(
|
||||
k, v, saved_state, key_padding_mask, static_kv, bsz
|
||||
)
|
||||
saved_state.update(
|
||||
{
|
||||
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
||||
# assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
|
||||
|
||||
# Update cache
|
||||
layer_state[self.cache_key] = {
|
||||
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_key_padding_mask": key_padding_mask,
|
||||
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
||||
}
|
||||
)
|
||||
decoder_cached_states[self.cache_key] = saved_state # Update cache
|
||||
|
||||
assert k is not None
|
||||
src_len = k.size(1)
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
@@ -632,7 +633,7 @@ class SelfAttention(nn.Module):
|
||||
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len)
|
||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
|
||||
|
||||
if key_padding_mask is not None: # don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
@@ -650,7 +651,7 @@ class SelfAttention(nn.Module):
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
@@ -675,7 +676,7 @@ class SelfAttention(nn.Module):
|
||||
key_padding_mask = self._cat_prev_key_padding_mask(
|
||||
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
|
||||
)
|
||||
return k, v, key_padding_mask, saved_state
|
||||
return k, v, key_padding_mask
|
||||
|
||||
@staticmethod
|
||||
def _cat_prev_key_padding_mask(
|
||||
@@ -693,7 +694,6 @@ class SelfAttention(nn.Module):
|
||||
# During incremental decoding, as the padding token enters and
|
||||
# leaves the frame, there will be a time when prev or current is None
|
||||
elif prev_key_padding_mask is not None:
|
||||
|
||||
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
|
||||
if prev_key_padding_mask.is_cuda:
|
||||
filler = filler.cuda()
|
||||
@@ -747,8 +747,12 @@ class LearnedPositionalEmbedding(nn.Embedding):
|
||||
num_embeddings += padding_idx + 1 # WHY?
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input, generation_mode=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
if generation_mode: # the position is our current step in the decoded sequence
|
||||
pos = int(self.padding_idx + input.size(1))
|
||||
positions = input.data.new(1, 1).fill_(pos)
|
||||
else:
|
||||
positions = create_position_ids_from_input_ids(input, self.padding_idx)
|
||||
return super().forward(positions)
|
||||
|
||||
@@ -826,21 +830,20 @@ class BartModel(PretrainedBartModel):
|
||||
assert attention_mask.max() <= 0
|
||||
|
||||
# make masks if user doesn't supply
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
|
||||
if not self.decoder.generation_mode:
|
||||
decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
|
||||
self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
if encoder_outputs is None:
|
||||
# TODO(SS): make this caching more usable when overwrite generate
|
||||
encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
|
||||
assert isinstance(encoder_outputs, tuple)
|
||||
# dec_features, decoder_cached_states, dec_hidden, dec_attn
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder.forward(
|
||||
decoder_input_ids,
|
||||
encoder_outputs[0],
|
||||
attention_mask,
|
||||
decoder_attn_mask,
|
||||
decoder_attention_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
)
|
||||
# Attention and hidden_states will be [] or None if they aren't needed
|
||||
@@ -856,20 +859,26 @@ class BartModel(PretrainedBartModel):
|
||||
self.shared = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return _make_linear_from_emb(self.shared)
|
||||
return _make_linear_from_emb(self.shared) # make it on the fly
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BART Model with a language modeling head", BART_START_DOCSTRING,
|
||||
"The bare BART Model with a language modeling head. This is the model used for summarization.",
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class BartForMaskedLM(PretrainedBartModel):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
self.model = BartModel(config)
|
||||
# if base_model is None:
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.lm_head = _make_linear_from_emb(self.model.shared)
|
||||
|
||||
def tie_weights(self):
|
||||
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@@ -935,12 +944,309 @@ class BartForMaskedLM(PretrainedBartModel):
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(input_ids, past, **kwargs):
|
||||
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
|
||||
def prepare_inputs_for_generation(input_ids, past, decoder_input_ids, attention_mask):
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
return {
|
||||
"input_ids": input_ids, # ignored after first pass
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
# "decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_cached_states) = past
|
||||
reordered_past = []
|
||||
for layer_past in decoder_cached_states:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {
|
||||
attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
}
|
||||
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
reordered_past.append(layer_past_new)
|
||||
new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
|
||||
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
|
||||
|
||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||
return past
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
max_length=20,
|
||||
num_beams=1,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
num_return_sequences=1,
|
||||
min_len=0,
|
||||
no_repeat_ngram_size=0,
|
||||
):
|
||||
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
||||
and beam-search.
|
||||
|
||||
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
|
||||
|
||||
.. _`XLM beam search code`:
|
||||
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
||||
.. _`Fairseq beam search code`:
|
||||
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
|
||||
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
||||
The sequence used as a prompt for the generation. If `None` the method initializes
|
||||
it as an empty `torch.LongTensor` of shape `(1,)`.
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Does not include tokens in input_ids.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
num_return_sequences: (`optional`) int
|
||||
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
||||
|
||||
min_len: (`optional`) int
|
||||
|
||||
Returns:
|
||||
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
|
||||
sequence_length is <= max_length (examples can finish early)
|
||||
|
||||
Examples::
|
||||
|
||||
config = BartConfig(vocab_size=50264, output_past=True)
|
||||
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
|
||||
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
|
||||
# Generate Summary
|
||||
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
|
||||
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])
|
||||
|
||||
"""
|
||||
bos_token_id = self.config.bos_token_id
|
||||
pad_token_id = self.config.pad_token_id
|
||||
eos_token_id = self.config.eos_token_id
|
||||
batch_size, cur_len = input_ids.shape
|
||||
assert input_ids is not None
|
||||
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||
assert isinstance(pad_token_id, int)
|
||||
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
|
||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
), "`num_return_sequences` should be a positive integer."
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = input_ids.shape[1]
|
||||
vocab_size = self.config.vocab_size
|
||||
|
||||
if num_return_sequences != 1:
|
||||
# Expand input to num return sequences
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||
input_ids = input_ids.contiguous().view(
|
||||
batch_size * num_return_sequences, cur_len
|
||||
) # shape: (batch_size * num_return_sequences, cur_len)
|
||||
batch_size *= num_return_sequences
|
||||
|
||||
# Below here somewhat similar to PretrainedModel._generate_beam_search
|
||||
# Expand input to num beams
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
|
||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||
if attention_mask is not None:
|
||||
attention_mask = (
|
||||
attention_mask.unsqueeze(1)
|
||||
.expand(batch_size, num_beams, cur_len)
|
||||
.contiguous()
|
||||
.view(batch_size * num_beams, cur_len)
|
||||
) # RESHAPE
|
||||
|
||||
# generated hypotheses
|
||||
finalized_hyps = [ # they end in EOS and we wont work on them more!
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9 # avoid ties in first step
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# decoder tokens
|
||||
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
|
||||
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
|
||||
decoder_cache = None
|
||||
done = [False for _ in range(batch_size)] # done sentences
|
||||
|
||||
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
|
||||
for step in range(max_length + 1):
|
||||
decoder_input_ids = prev_output_tokens.clone()
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, decoder_cache, decoder_input_ids, attention_mask,
|
||||
)
|
||||
outputs = self(**model_inputs)
|
||||
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
|
||||
|
||||
lprobs[lprobs != lprobs] = -math.inf # block nans
|
||||
lprobs[:, pad_token_id] = -math.inf
|
||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||
|
||||
if step == 0: # Force BOS to be chosen
|
||||
lprobs[:, bos_token_id + 1 :] = -math.inf
|
||||
elif step < min_len: # Prevent EOS from being chosen
|
||||
lprobs[:, eos_token_id] = -math.inf
|
||||
elif step == max_length: # FORCE EOS to be chosen
|
||||
lprobs[:, :eos_token_id] = -math.inf
|
||||
lprobs[:, eos_token_id + 1 :] = -math.inf
|
||||
assert self._do_output_past(outputs)
|
||||
decoder_cache = outputs[1]
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
|
||||
num_hypos = batch_size * num_beams
|
||||
if no_repeat_ngram_size > 0: # copied from fairseq
|
||||
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
|
||||
# then set their probabilities tof -inf
|
||||
for idx in range(num_hypos):
|
||||
lprobs[idx, banned_tokens[idx]] = -math.inf
|
||||
assert lprobs.size() == (batch_size * num_beams, vocab_size)
|
||||
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
||||
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
|
||||
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
|
||||
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
|
||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
|
||||
|
||||
# list of (batch_size * num_beams)
|
||||
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
|
||||
for batch_idx in range(batch_size):
|
||||
# if we are done with this sentence (because we can't improve)
|
||||
if done[batch_idx]: # then pad all associated hypotheses
|
||||
assert (
|
||||
len(finalized_hyps[batch_idx]) >= num_beams
|
||||
), "Example can only be done if at least {} beams have been generated".format(num_beams)
|
||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
|
||||
# Otherwise generate some next word choices
|
||||
next_sent_beam = []
|
||||
# add next words for this sentence
|
||||
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
|
||||
beam_id = idx // vocab_size
|
||||
word_id = idx % vocab_size
|
||||
assert prev_output_tokens.shape[1] == (step + 1)
|
||||
if word_id.item() == eos_token_id:
|
||||
if i >= num_beams:
|
||||
continue
|
||||
finalized_hyps[batch_idx].add(
|
||||
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
|
||||
)
|
||||
else:
|
||||
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
|
||||
|
||||
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
|
||||
break
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len=step + 1,
|
||||
)
|
||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
||||
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * num_beams
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||
# re-order decoder inputs to [beam_idx]
|
||||
prev_output_tokens = prev_output_tokens[beam_idx]
|
||||
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
# Add all open beam hypothesis to generated_hyps
|
||||
if done[batch_idx]:
|
||||
continue
|
||||
offset = batch_idx * num_beams
|
||||
for i in range(num_beams):
|
||||
score = beam_scores[offset + i]
|
||||
final_tokens = prev_output_tokens[offset + i]
|
||||
finalized_hyps[batch_idx].add(final_tokens, score.item())
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size)
|
||||
best = []
|
||||
for i, hypotheses in enumerate(finalized_hyps):
|
||||
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
|
||||
sent_lengths[i] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are filled with pad_token
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
# TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
|
||||
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
return decoded[:, 1:] # get rid of starting EOS
|
||||
|
||||
@staticmethod
|
||||
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
# TODO(SS): this can go on parent if there is demand
|
||||
if step + 2 < no_repeat_ngram_size:
|
||||
return [
|
||||
[] for _ in range(num_hypos)
|
||||
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
gen_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_output_tokens[idx].tolist()
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
||||
k = tuple(ngram[:-1])
|
||||
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
|
||||
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
|
||||
return gen_ngrams[hypo_idx].get(ngram_index, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
return banned_tokens
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
@@ -171,7 +171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
else:
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
|
||||
if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None:
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
output_embeddings.bias.data = torch.nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||
@@ -558,7 +558,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
|
||||
model.tie_weights() # make sure word embedding weights are still tied if needed
|
||||
|
||||
# Set model in evaluation mode to desactivate DropOut modules by default
|
||||
@@ -574,15 +573,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
|
||||
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
|
||||
|
||||
if has_output_past and not has_mem_len and len(outputs) > 1:
|
||||
return True
|
||||
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
|
||||
return True
|
||||
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
has_output_past = getattr(self.config, "output_past", False)
|
||||
mem_len = getattr(self.config, "mem_len", 0)
|
||||
if len(outputs) <= 1:
|
||||
return False
|
||||
if mem_len > 0 or has_output_past:
|
||||
return True
|
||||
return False
|
||||
|
||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
|
||||
for i in range(batch_size * num_beams):
|
||||
for previous_token in set(prev_output_tokens[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if lprobs[i, previous_token] < 0:
|
||||
lprobs[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
lprobs[i, previous_token] /= repetition_penalty
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
@@ -761,7 +769,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
|
||||
input_ids = input_ids.contiguous().view(
|
||||
batch_size * num_return_sequences, cur_len
|
||||
) # (batch_size * num_return_sequences, cur_len)
|
||||
) # shape: (batch_size * num_return_sequences, cur_len)
|
||||
effective_batch_size = batch_size * num_return_sequences
|
||||
else:
|
||||
effective_batch_size = batch_size
|
||||
@@ -822,9 +830,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
past = None
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
@@ -834,13 +842,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if next_token_logits[i, previous_token] < 0:
|
||||
next_token_logits[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
next_token_logits[i, previous_token] /= repetition_penalty
|
||||
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -911,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
# Expand input to num beams
|
||||
# assert input_ids.shape == (batch_size * num_beams, cur_len)
|
||||
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
|
||||
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
|
||||
|
||||
@@ -941,13 +944,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size * num_beams):
|
||||
for previous_token in set(input_ids[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if scores[i, previous_token] < 0:
|
||||
scores[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
scores[i, previous_token] /= repetition_penalty
|
||||
self.enforce_repetition_penalty_(scores, batch_size, num_beams, input_ids, repetition_penalty)
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -1039,16 +1036,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# re-order internal states
|
||||
if past:
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
@@ -1096,6 +1084,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
return past
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
@@ -1164,17 +1166,22 @@ class BeamHypotheses(object):
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs):
|
||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
|
||||
if len(self) < self.num_beams:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
|
||||
if cur_len is None:
|
||||
cur_len = self.max_length
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
ret = self.worst_score >= cur_score
|
||||
return ret
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
|
||||
@@ -19,11 +19,7 @@ from .tokenization_roberta import RobertaTokenizer
|
||||
# vocab and merges same as roberta
|
||||
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
|
||||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_bart_models = [
|
||||
"bart-large",
|
||||
"bart-large-mnli",
|
||||
# "bart-large-cnn"
|
||||
]
|
||||
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
|
||||
|
||||
|
||||
class BartTokenizer(RobertaTokenizer):
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user