Add beam search
This commit is contained in:
committed by
Julien Chaumond
parent
1c71ecc880
commit
9660ba1cbd
@@ -1,502 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2019 The HuggingFace Inc. team.
|
|
||||||
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" Finetuning seq2seq models for sequence generation."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
import torch
|
|
||||||
from torch.optim import Adam
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
AutoTokenizer,
|
|
||||||
BertForMaskedLM,
|
|
||||||
BertConfig,
|
|
||||||
PreTrainedEncoderDecoder,
|
|
||||||
Model2Model,
|
|
||||||
)
|
|
||||||
|
|
||||||
from utils_summarization import (
|
|
||||||
CNNDailyMailDataset,
|
|
||||||
encode_for_summarization,
|
|
||||||
fit_to_block_size,
|
|
||||||
build_lm_labels,
|
|
||||||
build_mask,
|
|
||||||
compute_token_type_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
    """Seed Python's `random`, NumPy and PyTorch with `args.seed`."""
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
|
||||||
|
|
||||||
|
|
||||||
# ------------
|
|
||||||
# Load dataset
|
|
||||||
# ------------
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """ Build the CNN/DailyMail dataset located in `args.data_dir`.

    Args:
        args: parsed command-line arguments; only `args.data_dir` is read.
        tokenizer: tokenizer handed over to the dataset.
        evaluate: accepted so that the `evaluate=True` call made by
            `evaluate()` in this file no longer raises a TypeError.

    Returns:
        A `CNNDailyMailDataset` instance.
    """
    # NOTE(review): `evaluate` does not select a different split yet; the
    # dataset is built with its default `prefix`. Confirm what the
    # evaluation prefix should be before wiring it in.
    return CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def collate(data, tokenizer, block_size):
    """ Collate a list of (story, summary) tuples into model inputs.

    Pairs with an empty story or an empty summary are dropped. The others
    are encoded, then truncated or padded to `block_size`.

    Returns:
        A tuple `(stories, summaries, encoder_token_type_ids, encoder_mask,
        decoder_mask, lm_labels)`.
    """
    pad_id = tokenizer.pad_token_id

    fitted_stories = []
    fitted_summaries = []
    for raw_story, raw_summary in data:
        # skip the files with an empty story or an empty summary
        if len(raw_story) == 0 or len(raw_summary) == 0:
            continue
        story_ids, summary_ids = encode_for_summarization(
            raw_story, raw_summary, tokenizer
        )
        fitted_stories.append(fit_to_block_size(story_ids, block_size, pad_id))
        fitted_summaries.append(fit_to_block_size(summary_ids, block_size, pad_id))

    stories = torch.tensor(fitted_stories)
    summaries = torch.tensor(fitted_summaries)

    encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(stories, pad_id)
    decoder_mask = build_mask(summaries, pad_id)
    lm_labels = build_lm_labels(summaries, pad_id)

    return (
        stories,
        summaries,
        encoder_token_type_ids,
        encoder_mask,
        decoder_mask,
        lm_labels,
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# Optimizers
|
|
||||||
# ----------
|
|
||||||
|
|
||||||
|
|
||||||
class BertSumOptimizer(object):
    """ Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler:

        rate = lr * min(step ** -0.5, step * warmup_steps ** -1.5)

    Args:
        model: object exposing `encoder` and `decoder` modules.
        lr: dict with the keys "encoder" and "decoder" (base learning rates).
        warmup_steps: dict with the keys "encoder" and "decoder".
        beta_1, beta_2, eps: forwarded to both `Adam` optimizers.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0

    def _update_rate(self, stack):
        """ Learning rate of `stack` ("encoder" or "decoder") at the current step.

        The exponent on the warm-up term is -1.5 so that both branches of the
        `min` meet exactly at `step == warmup_steps`; with -0.5 the warm-up
        would end after roughly `warmup_steps ** (1 / 3)` steps.
        Must be called after at least one `step()`: `0 ** -0.5` is undefined.
        """
        return self.lr[stack] * min(
            self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
        )

    def zero_grad(self):
        """ Reset the gradients tracked by both optimizers. """
        # The optimizers only live in `self.optimizers`; the class never
        # defines `optimizer_encoder` / `optimizer_decoder` attributes.
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        """ Advance the schedule by one step and update both stacks. """
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
# ------------
|
|
||||||
# Train
|
|
||||||
# ------------
|
|
||||||
|
|
||||||
|
|
||||||
def train(args, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus.

    Reads `per_gpu_train_batch_size`, `n_gpu`, `max_steps`,
    `num_train_epochs`, `gradient_accumulation_steps`, `max_grad_norm` and
    `device` from `args`; writes `args.train_batch_size` and, when
    `max_steps > 0`, `args.num_train_epochs`.

    Returns:
        `(global_step, average_loss)`. The average is computed over the
        optimization steps and is `tr_loss` itself when no step was run.
    """
    set_seed(args)

    # Load the data
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_dataset = load_and_cache_examples(args, tokenizer)
    train_sampler = RandomSampler(train_dataset)
    model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=model_collate_fn,
    )

    # Training schedule
    if args.max_steps > 0:
        t_total = args.max_steps
        # Number of epochs needed to reach `max_steps`, rounded up. The "+ 1"
        # belongs outside the division; inside it, the epoch count could be 0
        # and the training loop would never run.
        steps_per_epoch = max(
            1, len(train_dataloader) // args.gradient_accumulation_steps
        )
        args.num_train_epochs = t_total // steps_per_epoch + 1
    else:
        t_total = (
            len(train_dataloader)
            // args.gradient_accumulation_steps
            * args.num_train_epochs
        )

    # Prepare the optimizer
    lr = {"encoder": 0.002, "decoder": 0.2}
    warmup_steps = {"encoder": 20000, "decoder": 10000}
    optimizer = BertSumOptimizer(model, lr, warmup_steps)

    # Train
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(
        " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
    )
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    model.zero_grad()
    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)

    global_step = 0
    tr_loss = 0.0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
        for step, batch in enumerate(epoch_iterator):
            source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch

            source = source.to(args.device)
            target = target.to(args.device)
            encoder_token_type_ids = encoder_token_type_ids.to(args.device)
            encoder_mask = encoder_mask.to(args.device)
            decoder_mask = decoder_mask.to(args.device)
            lm_labels = lm_labels.to(args.device)

            model.train()
            outputs = model(
                source,
                target,
                encoder_token_type_ids=encoder_token_type_ids,
                encoder_attention_mask=encoder_mask,
                decoder_attention_mask=decoder_mask,
                decoder_lm_labels=lm_labels,
            )

            loss = outputs[0]
            # Report through the module logger instead of a bare print().
            logger.info(" loss = %s", loss.item())
            if args.gradient_accumulation_steps > 1:
                loss /= args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                model.zero_grad()
                global_step += 1

            # Stop once `max_steps` updates were made; `>` ran one extra update
            # compared to the "Total optimization steps" logged above.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    # Guard against a division by zero when no optimization step was run.
    return global_step, tr_loss / max(global_step, 1)
|
|
||||||
|
|
||||||
|
|
||||||
# ------------
|
|
||||||
# Evaluate
|
|
||||||
# ------------
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
    """ Compute the perplexity of the model and save it to disk.

    Reads `per_gpu_eval_batch_size`, `n_gpu`, `device` and `output_dir` from
    `args`; writes `args.eval_batch_size`.

    Returns:
        A dict with the single key "perplexity". The same value is written to
        `eval_results.txt` in `args.output_dir`.
    """
    set_seed(args)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
    eval_sampler = SequentialSampler(eval_dataset)
    # The loop below unpacks each batch into the six tensors built by
    # `collate`, so the loader needs the same collate function (and block
    # size) as the one used for training.
    model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=args.eval_batch_size,
        collate_fn=model_collate_fn,
    )

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch

        source = source.to(args.device)
        target = target.to(args.device)
        encoder_token_type_ids = encoder_token_type_ids.to(args.device)
        encoder_mask = encoder_mask.to(args.device)
        decoder_mask = decoder_mask.to(args.device)
        lm_labels = lm_labels.to(args.device)

        with torch.no_grad():
            outputs = model(
                source,
                target,
                encoder_token_type_ids=encoder_token_type_ids,
                encoder_attention_mask=encoder_mask,
                decoder_attention_mask=decoder_mask,
                decoder_lm_labels=lm_labels,
            )
            lm_loss = outputs[0]
            # `.mean()` presumably reduces the per-device losses returned
            # under DataParallel; confirm against the model's output shape.
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    # Save the evaluation's results
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
|
|
||||||
|
|
||||||
|
|
||||||
def save_model_checkpoints(args, model, tokenizer):
    """ Write the model, the tokenizer and the arguments to `args.output_dir`.

    The directory is created when it does not exist. Model and tokenizer are
    saved with `save_pretrained()`; the arguments are saved with
    `torch.save` under `training_arguments.bin`.
    """
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger.info("Saving model checkpoint to %s", output_dir)

    # Distributed/parallel wrappers expose the underlying model as `.module`.
    unwrapped_model = getattr(model, "module", model)
    unwrapped_model.save_pretrained(output_dir, model_type='bert')
    tokenizer.save_pretrained(output_dir)

    arguments_file = os.path.join(output_dir, "training_arguments.bin")
    torch.save(args, arguments_file)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """ Parse the command line, then train and/or load the model for evaluation.

    Returns:
        The `results` dict (left empty by the current evaluation branch).

    Raises:
        ValueError: if training is requested on a non-empty output directory
            without `--do_overwrite_output_dir`.
    """

    def _str_to_bool(value):
        """ Parse a boolean flag value.

        `type=bool` turns every non-empty string, "False" included, into
        True; this parser keeps `--flag True` working and makes
        `--flag False` mean what it says.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got {!r}.".format(value)
        )

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--do_evaluate",
        type=_str_to_bool,
        default=False,
        help="Run model evaluation on out-of-sample data.",
    )
    parser.add_argument(
        "--do_train", type=_str_to_bool, default=False, help="Run training."
    )
    parser.add_argument(
        "--do_overwrite_output_dir",
        type=_str_to_bool,
        default=False,
        help="Whether to overwrite the output dir.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--to_cpu",
        default=False,
        type=_str_to_bool,
        help="Whether to force training on CPU.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=10,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    # `evaluate()` reads `args.per_gpu_eval_batch_size`; declare it so the
    # attribute exists.
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.do_overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
                args.output_dir
            )
        )

    # Set up training device
    if args.to_cpu or not torch.cuda.is_available():
        args.device = torch.device("cpu")
        args.n_gpu = 0
    else:
        args.device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()

    # Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    config = BertConfig.from_pretrained(args.model_name_or_path)
    decoder_model = BertForMaskedLM(config)
    model = Model2Model.from_pretrained(
        args.model_name_or_path, decoder_model=decoder_model
    )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        0,
        args.device,
        args.n_gpu,
        False,
        False,
    )

    logger.info("Training/evaluation parameters %s", args)

    # Train the model
    model.to(args.device)
    if args.do_train:
        try:
            global_step, tr_loss = train(args, model, tokenizer)
        except KeyboardInterrupt:
            response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
            if response.lower() in ["", "y", "yes"]:
                save_model_checkpoints(args, model, tokenizer)
            # `global_step` and `tr_loss` are undefined past this point.
            sys.exit(0)

        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
        save_model_checkpoints(args, model, tokenizer)

    # Evaluate the model
    results = {}
    if args.do_evaluate:
        checkpoints = [args.output_dir]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
            decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
            model = PreTrainedEncoderDecoder.from_pretrained(
                encoder_checkpoint, decoder_checkpoint
            )
            model.to(args.device)
            print("model loaded")

    return results
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -25,9 +25,8 @@ class CNNDailyMailDataset(Dataset):
|
|||||||
[2] https://github.com/abisee/cnn-dailymail/
|
[2] https://github.com/abisee/cnn-dailymail/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tokenizer, prefix="train", data_dir=""):
|
def __init__(self, data_dir="", prefix="train"):
|
||||||
assert os.path.isdir(data_dir)
|
assert os.path.isdir(data_dir)
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
# We initialize the class by listing all the files that contain
|
# We initialize the class by listing all the files that contain
|
||||||
# stories and summaries. Files are not read in memory given
|
# stories and summaries. Files are not read in memory given
|
||||||
@@ -104,31 +103,30 @@ def _add_missing_period(line):
|
|||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
|
|
||||||
def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the source and target sequences' lengths to the block size.

    If the sequence is longer than `block_size` it is truncated; if it is
    shorter, `pad_token_id` is appended to the right until it is
    `block_size` long.

    A new list is returned in both cases. The input is never modified:
    padding used to extend the caller's list in place while truncation
    returned a copy.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    return list(sequence) + [pad_token_id] * (block_size - len(sequence))
|
||||||
|
|
||||||
|
|
||||||
def build_lm_labels(sequence, pad_token_id):
    """ Return a copy of `sequence` in which every padding token is replaced
    by -1, so padded positions are not taken into account in the loss
    computation. """
    is_padding = sequence == pad_token_id
    return sequence.masked_fill(is_padding, -1)
|
||||||
|
|
||||||
|
|
||||||
def build_mask(sequence, pad_token_id):
    """ Build the attention mask of `sequence`: 1 where there is a real
    token, 0 where there is padding. The attention mechanism will only
    attend to positions with value 1. """
    is_real_token = sequence != pad_token_id
    return is_real_token.to(sequence.dtype)
|
||||||
|
|
||||||
|
|||||||
1
transformers/generate/__init__.py
Normal file
1
transformers/generate/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .beam_search import BeamSearch
|
||||||
358
transformers/generate/beam_search.py
Normal file
358
transformers/generate/beam_search.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# MIT License
|
||||||
|
|
||||||
|
# Copyright (c) 2017-Present OpenNMT
|
||||||
|
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
# this software and associated documentation files (the "Software"), to deal in
|
||||||
|
# the Software without restriction, including without limitation the rights to
|
||||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
||||||
|
# of the Software, and to permit persons to whom the Software is furnished to do
|
||||||
|
# so, subject to the following conditions:
|
||||||
|
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
"""
|
||||||
|
Use Beam Search to generate sequences using encoder-decoder models.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearch(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
beam_size,
|
||||||
|
min_length,
|
||||||
|
max_length,
|
||||||
|
batch_size=1,
|
||||||
|
alpha=0,
|
||||||
|
block_repeating_trigrams=True,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Inputs:
|
||||||
|
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
||||||
|
The pretrained encoder-decoder model that will be used to generate the sequences.
|
||||||
|
**tokenizer**: instance of ``transformers.PreTrainedTokenizer``
|
||||||
|
The pretrained tokenizer associated to the model used in the encoder-decoder. We only
|
||||||
|
support encoder-decoder that use the same tokenizer for encoder and decoder. The tokenizer
|
||||||
|
needs to be initialized or this function will raise and exception.
|
||||||
|
**batch_size**: (`optional`) int
|
||||||
|
Batch size of the inputs. The value is set automatically when calling `forward`.
|
||||||
|
**beam_size**: int
|
||||||
|
Number of beams that are used for each element on the batch.
|
||||||
|
**min_length**: int
|
||||||
|
Minimum number of steps performed by the beam search before terminating.
|
||||||
|
**max_length**: int
|
||||||
|
Maximum number of steps performed by the beam search. Any beam that has not finished
|
||||||
|
will return its current solution with the highest probability. The sequence that is
|
||||||
|
returned has a length of max_length-1 to account for the end token that is subsequently added.
|
||||||
|
**alpha**: float
|
||||||
|
Parameter of the length penalty. Read the documentation of the `_length_penalty` method for mode details.
|
||||||
|
**block_repeating_trigrams**: bool
|
||||||
|
Whether to block sequences that have repeating 3-grams.
|
||||||
|
"""
|
||||||
|
super(BeamSearch, self).__init__()
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
self.bos_token_id = tokenizer.bos_token_id
|
||||||
|
self.eos_token_id = tokenizer.eos_token_id
|
||||||
|
self.pad_token_id = tokenizer.pad_token_id
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.min_length = min_length
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
self.block_repeating_trigram = block_repeating_trigrams
|
||||||
|
self.apply_length_penalty = False if alpha == 0 else True
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
|
self._init_beam_state(batch_size)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
try:
|
||||||
|
return self.growing_beams.size(1)
|
||||||
|
except NameError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _init_beam_state(self, batch_size):
|
||||||
|
""" (re-)Initialize the state of the beams. """
|
||||||
|
self.hypotheses = [[] for _ in range(batch_size)]
|
||||||
|
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
|
||||||
|
self.beam_offset = torch.arange(
|
||||||
|
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
|
||||||
|
)
|
||||||
|
self.growing_beams = torch.full(
|
||||||
|
(batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long
|
||||||
|
)
|
||||||
|
self.topk_log_probabilities = torch.tensor(
|
||||||
|
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
|
||||||
|
).repeat(batch_size)
|
||||||
|
self.results = {
|
||||||
|
"predictions": [[] for _ in range(batch_size)],
|
||||||
|
"scores": [[] for _ in range(batch_size)],
|
||||||
|
}
|
||||||
|
self._step = 0
|
||||||
|
self.is_done = False
|
||||||
|
|
||||||
|
def forward(self, encoder_input_ids, **model_kwargs):
    """ Generate a sequence using Beam Search.

    Keyword arguments come in 3 flavors: encoder-specific (prefixed by
    `encoder_`), decoder-specific (prefixed by `decoder_`) and those that
    apply to the model as whole. The prefix is stripped before the argument
    is handed to the encoder or the decoder, and the specific kwargs
    override the common ones in case of conflict.

    Returns:
        `self.results`, the dict with the keys "predictions" and "scores"
        filled in while the beams grow.
    """
    kwargs_common = {
        argument: value
        for argument, value in model_kwargs.items()
        if not argument.startswith("encoder_") and not argument.startswith("decoder_")
    }
    kwargs_decoder = kwargs_common.copy()
    kwargs_encoder = kwargs_common.copy()
    kwargs_encoder.update(
        {
            argument[len("encoder_"):]: value
            for argument, value in model_kwargs.items()
            if argument.startswith("encoder_")
        }
    )
    kwargs_decoder.update(
        {
            argument[len("decoder_"):]: value
            for argument, value in model_kwargs.items()
            if argument.startswith("decoder_")
        }
    )

    # forward pass on the encoder; the kwargs dict has to be unpacked,
    # passing it positionally binds the whole dict to the second parameter.
    encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
    # NOTE(review): `tile` is not defined or imported in the visible part of
    # this module, and `encoder_outputs` may be a tuple rather than the
    # hidden-states tensor -- confirm both.
    kwargs_decoder["encoder_hidden_states"] = tile(
        encoder_outputs, self.beam_size, dim=0
    )

    # grow the beam by generating sequences in an autoregressive way
    batch_size = encoder_input_ids.size(0)
    self._init_beam_state(batch_size)
    for _ in range(self.max_length):
        # prepare the decoder input
        # NOTE(review): `fit_to_block_size`, `build_lm_labels` and
        # `build_mask` are not imported in the visible part of this module,
        # and `fit_to_block_size` is called here without a block size --
        # confirm the intended helper and arguments.
        decoder_input = fit_to_block_size(
            self.growing_beams, self.tokenizer.pad_token_id
        )
        # Keys of `kwargs_decoder` are the decoder's own argument names: the
        # `decoder_` prefix is stripped above, so it must not be added here.
        kwargs_decoder["lm_labels"] = build_lm_labels(
            decoder_input, self.tokenizer.pad_token_id
        )
        kwargs_decoder["attention_mask"] = build_mask(
            decoder_input, self.tokenizer.pad_token_id
        )

        outputs = self.model.decoder(decoder_input, **kwargs_decoder)
        # Normalize over the last dimension explicitly; without `dim` the
        # dimension is chosen implicitly from the number of dimensions.
        # Presumably `outputs[1]` holds the logits because labels are passed
        # -- TODO confirm against the decoder's return value.
        log_probabilities = torch.nn.functional.log_softmax(outputs[1], dim=-1)
        surviving_beams_rows = self.grow(log_probabilities)
        if self.is_done:
            break

        kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
            "encoder_hidden_states"
        ].index_select(0, surviving_beams_rows)
        # The encoder attention mask is optional: only re-index it when the
        # caller supplied one.
        if kwargs_decoder.get("encoder_attention_mask") is not None:
            kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
                "encoder_attention_mask"
            ].index_select(0, surviving_beams_rows)

    return self.results
|
||||||
|
|
||||||
|
def grow(self, log_probabilities):
    """ Grow the beams by one step.

    Args:
        log_probabilities: tensor whose first dimension indexes the beams
            and whose last dimension indexes the vocabulary. It is updated
            in place with the running score of each beam.

    Returns:
        The row indices (in the flattened beam dimension) of the beams that
        survive this step.
    """
    self._step += 1

    # The number of beams changes as some beams finish so we define _B
    vocab_size = log_probabilities.size(-1)
    _B = log_probabilities.size(0) // self.beam_size

    # Multiply each beam probability with the probability of the
    # next token (conditioned on the words in the beam).
    log_probabilities += self.topk_log_probabilities.view(-1, 1)

    self._enforce_min_length(log_probabilities)
    if self.block_repeating_trigram:
        self._remove_beams_with_repeating_trigrams(log_probabilities, _B)

    # Find the `beam_size` (previous_beam + token) combinations with
    # the highest score
    topk_log_probabilities, topk_ids = torch.topk(
        log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
    )
    # Keep the accumulated log-probabilities of the surviving beams: they
    # are what gets added to the next step's token log-probabilities.
    self.topk_log_probabilities = topk_log_probabilities

    # Apply the length penalty. The division is not done in place so the
    # raw log-probabilities stored above are left untouched.
    topk_scores = topk_log_probabilities
    if self.apply_length_penalty:
        topk_scores = topk_log_probabilities / self._length_penalty()

    # Retrieve the corresponding respective beam and token id
    # topk_token_ids[i] will be added to topk_beam_ids[i]
    # Floor division keeps the beam ids integral; `.div` on integer tensors
    # is a true division on current PyTorch.
    topk_beam_ids = topk_ids // vocab_size
    topk_token_ids = topk_ids.fmod(vocab_size)

    # Retrieve the row index of the surviving beams in the original
    # view of the log_probabilities tensor
    surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
    surviving_beams_rows = surviving_beams_per_batch.view(-1)

    # Append the last predictions
    self.growing_beams = torch.cat(
        [
            self.growing_beams.index_select(0, surviving_beams_rows),
            topk_token_ids.view(-1, 1),
        ],
        1,
    )

    # Check if any of the beam searches has ended during this
    # growth step. Also if top beam (most probable) has ended
    # for one element of the batch.
    is_finished = topk_token_ids.eq(self.eos_token_id)
    self._enforce_max_length(is_finished)
    if is_finished.any():
        non_finished = self._cut_finished(is_finished, topk_scores)
        self.batch_offset = self.batch_offset.index_select(0, non_finished)
        surviving_beams_per_batch = surviving_beams_per_batch.index_select(
            0, non_finished
        )
        self.topk_log_probabilities = self.topk_log_probabilities.index_select(
            0, non_finished
        )

        surviving_beams_rows = surviving_beams_per_batch.view(-1)
        self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)

    return surviving_beams_rows
|
||||||
|
|
||||||
|
def _cut_finished(self, is_finished, topk_scores):
    """ Store the hypotheses that ended at this step and report which batch
    elements still have to grow.

    Args:
        is_finished: boolean tensor indexed as [batch element, beam]; it is
            modified in place (every beam of a batch element whose top beam
            ended is flagged as finished).
        topk_scores: tensor indexed like `is_finished` holding the score
            stored alongside each finished beam.

    Returns:
        1-D tensor with the indices of the batch elements whose top beam
        has not ended. When it is empty `self.is_done` is set to True.
    """
    # `.eq(True)` yields a fresh tensor, so the in-place writes to
    # `is_finished` below do not alter this snapshot.
    top_beam_done = is_finished[:, 0].eq(True)

    sequence_length = self.growing_beams.size(1)
    beams_by_batch = self.growing_beams.view(-1, self.beam_size, sequence_length)

    for row in range(is_finished.size(0)):
        top_done = top_beam_done[row]
        if top_done:
            # The search is over for this element: keep all of its beams.
            is_finished[row].fill_(1)

        # Index of this element in the original (un-shrunk) batch.
        batch_id = self.batch_offset[row]

        # Each finished beam is kept as a (score, prediction) hypothesis.
        for beam in is_finished[row].nonzero().view(-1):
            self.hypotheses[batch_id].append(
                (topk_scores[row, beam], beams_by_batch[row, beam, :])
            )

        # Once the top beam has ended, the highest-scoring hypothesis
        # collected so far becomes the result for this batch element.
        if top_done:
            winner = max(self.hypotheses[batch_id], key=lambda hypothesis: hypothesis[0])
            self.results["scores"][batch_id].append(winner[0])
            self.results["predictions"][batch_id].append(winner[1])

    non_finished = top_beam_done.eq(False).nonzero().view(-1)
    if len(non_finished) == 0:
        self.is_done = True

    return non_finished
|
||||||
|
|
||||||
|
def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
    """ Penalise the beams whose most recent trigram already occurred.

    For each of the first `_B * self.beam_size` rows of
    `self.growing_beams`, the trigrams of the sequence are listed; when the
    last one also appears earlier, the whole matching row of
    `log_probabilities` is overwritten in place with -1e20.

    Args:
        log_probabilities: tensor with one row per beam, modified in place.
        _B: number of batch elements currently being decoded.
    """
    if self._step + 1 <= 3:  # [BOS] does not count
        return

    length = len(self)
    for i in range(_B * self.beam_size):
        # Convert the row to Python ints once: comparing tuples of 0-dim
        # tensors would dispatch a tensor op for every element compared.
        tokens = self.growing_beams[i].tolist()
        trigrams = [
            (tokens[j - 1], tokens[j], tokens[j + 1])
            for j in range(1, length - 1)
        ]
        if not trigrams:
            # Sequence too short to contain a trigram: nothing to block.
            continue
        if trigrams[-1] in trigrams[:-1]:
            log_probabilities[i] = -1e20
|
||||||
|
|
||||||
|
def _enforce_min_length(self, log_probabilities):
    """ While the current step is below `self.min_length`, overwrite in
    place the [EOS] column of `log_probabilities` with -1e20. """
    if self._step >= self.min_length:
        return
    eos_column = self.eos_token_id
    log_probabilities[:, eos_column] = -1e20
|
||||||
|
|
||||||
|
def _enforce_max_length(self, is_finished):
    """ Flag every beam as finished (in place) when the next position
    would be the last one allowed by `self.max_length`. """
    # One slot is kept for the [EOS] token that still has to be appended.
    next_length = self._step + 1
    if next_length != self.max_length:
        return
    is_finished.fill_(1)
|
||||||
|
|
||||||
|
def _length_penalty(self):
    """ Length penalty computed as in [1]:

        ((5 + length) / 6) ** alpha

    where `length` is the current step plus one.

    [1] Wu, Yonghui, et al. "Google's neural machine translation system:
    Bridging the gap between human and machine translation." arXiv preprint
    arXiv:1609.08144 (2016).
    """
    length = self._step + 1
    base = (5.0 + length) / 6.0
    return base ** self.alpha
|
||||||
|
|
||||||
|
|
||||||
|
def tile(x, count, dim=0):
    """
    Repeat every slice of `x` taken along dimension `dim` `count` times in
    a row, so that the copies of one slice are adjacent.

    Example:
        >> ex = torch.tensor([[1,2],[3,4]])
        >> tile(ex, 2, 0)
        torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
    """
    swap_needed = dim != 0
    if swap_needed:
        # Bring the tiled dimension to the front; the same permutation
        # (a swap) restores the layout at the end.
        order = list(range(x.dim()))
        order[0], order[dim] = order[dim], order[0]
        x = x.permute(order).contiguous()

    rows = x.size(0)
    target_shape = [rows * count] + list(x.size()[1:])

    # Each flattened row is written `count` times side by side; reshaping
    # then turns those side-by-side copies into consecutive rows.
    flat = x.view(rows, -1)
    tiled = flat.repeat(1, count).view(*target_shape)

    if swap_needed:
        tiled = tiled.permute(order).contiguous()
    return tiled
|
||||||
|
|
||||||
|
|
||||||
|
def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the length of a sequence to the block size.

    A sequence longer than `block_size` is truncated on the right; a
    shorter one is padded on the right with `pad_token_id`.

    The input is never modified: a new list is returned in the padding
    case (the previous implementation extended the caller's list in place,
    while the truncation branch already returned a new object).

    Args:
        sequence: list of token ids.
        block_size: length of the returned sequence.
        pad_token_id: id appended to reach `block_size`.

    Returns:
        A sequence of exactly `block_size` elements.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]

    padded = list(sequence)
    padded.extend([pad_token_id] * (block_size - len(padded)))
    return padded
|
||||||
|
|
||||||
|
|
||||||
|
def build_lm_labels(sequence, pad_token_id):
    """ Return a copy of `sequence` in which every position equal to
    `pad_token_id` holds -1 instead. The input tensor is left untouched.

    The -1 value is meant to exclude padding from the loss computation
    (presumably through the loss' ignore index -- confirm against the
    caller). """
    is_padding = sequence == pad_token_id
    return sequence.masked_fill(is_padding, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def build_mask(sequence, pad_token_id):
    """ Build the attention mask of `sequence`: 0 wherever the value equals
    `pad_token_id` and 1 everywhere else, with the dtype of `sequence`.
    The attention mechanism is expected to attend only to the positions
    set to 1. """
    is_real_token = sequence != pad_token_id
    return is_real_token.type_as(sequence)
|
||||||
@@ -1,271 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright (c) 2019 Yang Liu
|
|
||||||
|
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
# of this software and associated documentation files (the "Software"), to deal
|
|
||||||
# in the Software without restriction, including without limitation the rights
|
|
||||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
# copies of the Software, and to permit persons to whom the Software is
|
|
||||||
# furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
# The above copyright notice and this permission notice shall be included in all
|
|
||||||
# copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
# SOFTWARE.
|
|
||||||
"""
|
|
||||||
A general wrapper around models with LM heads to generate sequences
|
|
||||||
using beam search.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerBeamSearch(nn.Module):
    """ Beam search over the log-probabilities produced by an
    encoder-decoder model with a language-modeling head.

    The search state is stored on the instance and advanced one token at a
    time by `step`; `forward` drives the full loop with `self.model`.
    """

    def __init__(
        self,
        model,
        tokenizer,
        batch_size,
        beam_size,
        min_length,
        max_length,
        alpha=0,
        block_repeating_trigram=True,
    ):
        """
        Args:
            model: object exposing `encoder` and `decoder` callables.
            tokenizer: object exposing `start_token_id`, `end_token_id` and
                `pad_token_id`.
            batch_size: number of sequences decoded in parallel.
            beam_size: number of beams kept per sequence.
            min_length: the end token is blocked while the step counter is
                below this value.
            max_length: every beam is ended when the step counter plus one
                reaches this value.
            alpha: exponent of the length penalty (0 gives a penalty of 1).
            block_repeating_trigram: whether beams containing a repeated
                trigram are penalised.
        """
        super(TransformerBeamSearch, self).__init__()
        self.model = model
        self.tokenizer = tokenizer

        self.start_token_id = tokenizer.start_token_id
        self.end_token_id = tokenizer.end_token_id
        self.pad_token_id = tokenizer.pad_token_id

        # `forward` needs the batch size to rebuild the beams.
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.min_length = min_length
        self.max_length = max_length

        self.block_repeating_trigram = block_repeating_trigram
        self.apply_length_penalty = alpha != 0
        self.alpha = alpha

        self._reset_state()

    def _reset_state(self):
        """ (Re)initialise the mutable state of the beam. """
        batch_size = self.batch_size

        # Finished (score, prediction) pairs, per original batch element.
        self.hypotheses = [[] for _ in range(batch_size)]
        # Original batch index of each element still being decoded.
        self.batch_offset = torch.arange(batch_size, dtype=torch.long)
        # Row of the first beam of each batch element.
        self.beam_offset = torch.arange(
            0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
        )
        self.growing_beam = torch.full(
            (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
        )
        # Only the first beam of each element starts with a finite score so
        # that the first step does not select `beam_size` identical beams.
        self.topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
        ).repeat(batch_size)
        self.results = {
            "predictions": [[] for _ in range(batch_size)],
            "scores": [[] for _ in range(batch_size)],
        }
        self._step = 0
        self.is_done = False

    def step(self, log_probabilities):
        """ Grows the beam by one step.

        Args:
            log_probabilities: tensor with one row per beam and one column
                per vocabulary entry. It is modified in place.

        Returns:
            1-D tensor with, for each beam kept for the next step, the row
            it was grown from in `log_probabilities`.
        """
        self._step += 1

        # The batch size changes as some beams finish so we define _B
        vocab_size = log_probabilities.size(-1)
        _B = log_probabilities.size(0) // self.beam_size

        # Multiply each beam probability with the probability of the
        # next token (conditioned on the words in the beam).
        log_probabilities += self.topk_log_probabilities.view(-1, 1)

        self.enforce_min_length(log_probabilities)
        if self.block_repeating_trigram:
            self.remove_repeating_trigrams(log_probabilities, _B)

        # Find the `beam_size` (previous_beam + token) combinations with
        # the highest score
        topk_log_probabilities, topk_ids = log_probabilities.view(
            _B, self.beam_size * vocab_size
        ).topk(self.beam_size, dim=1)

        # Apply the length penalty. The +1 accounts for the [EOS] token
        # that will be added if the beam ends.
        topk_scores = topk_log_probabilities / self.length_penalty()

        # Retrieve the corresponding respective beam and token id
        # topk_token_ids[i] will be added to topk_beam_ids[i]
        # (`//` keeps integer ids; `.div` is a true division on recent
        # PyTorch versions and would return floats).
        topk_beam_ids = topk_ids // vocab_size
        topk_token_ids = topk_ids.fmod(vocab_size)

        # Retrieve the row index of the surviving beams in the original
        # view of the log_probabilities tensor
        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
        surviving_beams_rows = surviving_beams_per_batch.view(-1)

        # Append the last predictions
        self.growing_beam = torch.cat(
            [
                self.growing_beam.index_select(0, surviving_beams_rows),
                topk_token_ids.view(-1, 1),
            ],
            1,
        )

        # Check if any of the beam searches has ended during this
        # growth step. Also if top beam (most probable) has ended
        # for one element of the batch.
        is_finished = topk_token_ids.eq(self.end_token_id)
        self.enforce_max_length(is_finished)

        if is_finished.any():
            # Copy: the rows of `is_finished` are overwritten below.
            is_top_beam_finished = is_finished[:, 0].clone()

            # Save the finished searches
            predictions = self.growing_beam.view(
                -1, self.beam_size, self.growing_beam.size(1)
            )
            for i in range(is_finished.size(0)):
                if is_top_beam_finished[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)

                # Store finished hypotheses for this batch.
                b = int(self.batch_offset[i])
                for j in finished_hyp:
                    self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                # If the batch reached the end, save the best hypotheses
                # in terms of length-penalized score.
                if is_top_beam_finished[i]:
                    best_score, best_prediction = max(
                        self.hypotheses[b], key=lambda x: x[0]
                    )
                    self.results["scores"][b].append(best_score)
                    self.results["predictions"][b].append(best_prediction)

            non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
            if len(non_finished) == 0:
                self.is_done = True

            # Remove finished batches for the next step. The selection is
            # done on the per-batch (2-D) views, since `non_finished`
            # indexes batch elements and not individual beams.
            topk_log_probabilities = topk_log_probabilities.index_select(
                0, non_finished
            )
            self.batch_offset = self.batch_offset.index_select(0, non_finished)
            self.growing_beam = predictions.index_select(0, non_finished).view(
                -1, self.growing_beam.size(-1)
            )
            surviving_beams_rows = surviving_beams_per_batch.index_select(
                0, non_finished
            ).view(-1)

        # Running score of each surviving beam, added to the next step's
        # log-probabilities.
        self.topk_log_probabilities = topk_log_probabilities.contiguous().view(-1)

        return surviving_beams_rows

    def forward(self, encoder_input_ids, **kwargs):
        """ Encode `encoder_input_ids` once, then decode step by step until
        the search is done or `max_length` steps were taken.

        Returns:
            `self.results`: a dict with the "predictions" and "scores"
            stored for each batch element.
        """
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as whole.
        # We let the specific kwargs override the common ones in case of conflict.
        kwargs_encoder = {
            argument[len("encoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("encoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }
        kwargs_common = {
            argument: value
            for argument, value in kwargs.items()
            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
        }
        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)

        # forward pass on the encoder
        encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
        if isinstance(encoder_outputs, (tuple, list)):
            # NOTE(review): assumes the hidden states come first when the
            # encoder returns a tuple -- confirm against the model used.
            encoder_outputs = encoder_outputs[0]
        kwargs_decoder["encoder_hidden_states"] = tile(
            encoder_outputs, self.beam_size, dim=0
        )

        # grow the beam by generating sequences in an autoregressive way;
        # the state is rebuilt so that the module can be called repeatedly.
        self._reset_state()
        for _ in range(self.max_length):
            # NOTE(review): only the last token is fed to the decoder and
            # the scores are read from `outputs[1]`, as in the original
            # code -- confirm both against the decoder's interface.
            decoder_input = self.growing_beam[:, -1]
            outputs = self.model.decoder(decoder_input, **kwargs_decoder)
            log_probabilities = torch.nn.functional.log_softmax(outputs[1], dim=-1)
            surviving_beams_rows = self.step(log_probabilities)
            if self.is_done:
                break

            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                "encoder_hidden_states"
            ].index_select(0, surviving_beams_rows)

        return self.results

    def remove_repeating_trigrams(self, log_probabilities, _B):
        """ Overwrite with -1e20 (in place) the rows of `log_probabilities`
        whose beam ends with a trigram it already contains. """
        if self._step + 1 > 3:
            for i in range(_B * self.beam_size):
                tokens = self.growing_beam[i].tolist()
                trigrams = [
                    (tokens[j - 1], tokens[j], tokens[j + 1])
                    for j in range(1, len(tokens) - 1)
                ]
                if trigrams and trigrams[-1] in trigrams[:-1]:
                    log_probabilities[i] = -1e20

    def enforce_min_length(self, log_probabilities):
        """ Block the end token (in place) while the step counter is below
        `min_length`. """
        if self._step < self.min_length:
            log_probabilities[:, self.end_token_id] = -1e20

    def enforce_max_length(self, is_finished):
        """ Flag every beam as finished (in place) when the step counter
        plus one reaches `max_length`. """
        if self._step + 1 == self.max_length:
            is_finished.fill_(1)

    def length_penalty(self):
        """ Return ((5 + step + 1) / 6) ** alpha. """
        return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
|
|
||||||
|
|
||||||
|
|
||||||
def tile(x, count, dim=0):
    """
    Repeat every slice of `x` taken along dimension `dim` `count` times in
    a row, so that the copies of one slice are adjacent.

    Example:
        >> ex = torch.tensor([[1,2],[3,4]])
        >> tile(ex, 2, 0)
        torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
    """
    swap_needed = dim != 0
    if swap_needed:
        # Bring the tiled dimension to the front; the same permutation
        # (a swap) restores the layout at the end.
        order = list(range(x.dim()))
        order[0], order[dim] = order[dim], order[0]
        x = x.permute(order).contiguous()

    rows = x.size(0)
    target_shape = [rows * count] + list(x.size()[1:])

    # Each flattened row is written `count` times side by side; reshaping
    # then turns those side-by-side copies into consecutive rows.
    flat = x.view(rows, -1)
    tiled = flat.repeat(1, count).view(*target_shape)

    if swap_needed:
        tiled = tiled.permute(order).contiguous()
    return tiled
|
|
||||||
226
transformers/tests/beam_search_tests.py
Normal file
226
transformers/tests/beam_search_tests.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.generate import BeamSearch
|
||||||
|
from transformers import PreTrainedEncoderDecoder
|
||||||
|
|
||||||
|
|
||||||
|
# Lightweight stand-in for a tokenizer: it only carries the three special
# token ids that the tests below pass to BeamSearch.
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
# Lightweight stand-in for an encoder-decoder model: a pair of `encoder` and
# `decoder` fields (the tests below fill them with placeholder strings).
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchtest(unittest.TestCase):
    """ Unit tests for the `grow` step of BeamSearch, driven with
    hand-crafted log-probabilities instead of a real model. """

    def test_beam_search_encoder_decoder_integration(self):
        """ We make sure that no internal change in the PreTrainedEncoderDecoder
        class will break the integration with the beam search.
        """

        model = PreTrainedEncoderDecoder("encoder", "decoder")
        tokenizer = StubTokenizer(0, 1, 2)
        try:
            _ = BeamSearch(
                model=model,
                tokenizer=tokenizer,
                batch_size=1,
                beam_size=1,
                min_length=1,
                max_length=1,
                alpha=0,
                block_repeating_trigrams=False,
            )
        except Exception:
            # `Exception` rather than a bare `except`, so that
            # KeyboardInterrupt/SystemExit are not reported as failures.
            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")

    def test_beam_search_min_length(self):
        """ We keep predicting the end_token for the first beam and check that
        it is not marked as finished until the beam has reached the minimum
        length. """
        eos_idx = 3
        vocab_size = 10

        batch_size = 3
        beam_size = 2
        min_length = 5

        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=min_length,
            max_length=10,
            alpha=0,
            block_repeating_trigrams=False,
        )

        # To test that the minimum length is correctly enforced we constantly
        # assign the highest probability to the [EOS] token (and assign lower
        # probabilities to some other tokens).
        # Since BeamSearch will reset its probability to -1e20 as long as
        # min_length has not been reached, we need to reset the value between
        # steps.
        non_eos_idxs = [4, 5, 1, 8, 9]
        score_distribution = torch.log_softmax(
            torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, eos_idx] = score_distribution[0]
        for idx, score in zip(non_eos_idxs, score_distribution[1:]):
            log_probabilities[0, idx] = score

        for step in range(1, min_length + 2):
            log_probabilities[0, eos_idx] = score_distribution[0]

            # Beam #3 and #4 terminate at the first step since the probability
            # of the [EOS] token is -1e20 > -\infty so there are only two beams left.
            surviving_beams_rows = beam.grow(log_probabilities)
            if step < min_length:
                np.testing.assert_array_equal(
                    beam.growing_beams.numpy(),
                    np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
                )
            elif step == min_length:
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_max_length(self):
        """ We keep predicting the same non-EOS token until we reach the
        maximum permitted length """
        batch_size = 3
        beam_size = 2
        max_length = 5
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that beam search enforces the max length constraint we
        # keep giving the highest probability to a token that is not the
        # [EOS] token.
        # The beam search will stop at max_length-1, assuming that one would
        # add the [EOS] token at the end of the returned sequence.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            surviving_beams_rows = beam.grow(log_probabilities)
            if step + 1 < max_length:
                self.assertFalse(beam.is_done)
            elif step + 1 == max_length:  # The maximum length has been reached
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_block_repeating_trigrams(self):
        """ We make sure that the beams that contain repeating trigrams are removed. """
        batch_size = 3
        beam_size = 2
        max_length = 10
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=True,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that BeamSearch enforces the 3-gram constraint we give the
        # highest probability to the same tokens in a cyclic fashion and make
        # sure they disappear once the cycle has completed.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            # Rotate the probabilities at each step
            for idx in token_idxs:
                score = score_distribution[(idx + step) % 3]
                log_probabilities[::beam_size, idx] = score

            surviving_beams_rows = beam.grow(log_probabilities)
            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

            if step < 7:
                self.assertFalse(
                    np.array_equal(
                        log_probabilities.numpy()[0, :],
                        np.array([-1e20] * vocab_size, dtype="float32"),
                    )
                )
            if step == 7:
                np.testing.assert_array_equal(
                    log_probabilities.numpy()[0, :],
                    np.array([-1e20] * vocab_size, dtype="float32"),
                )

    def test_beam_search_example_for_one_step(self):
        """ We test that the predictions for one step of growth are correct. """
        batch_size = 2
        beam_size = 2
        max_length = 10
        vocab_size = 5

        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
        log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

        # First pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
        )
        self.assertFalse(beam.is_done)

        # Second pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(),
            np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
        )
        self.assertFalse(beam.is_done)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test suite when the file is executed as a script. The guard has to
# compare against "__main__": the previous "__name__" literal never matched,
# so the tests were silently skipped.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user