diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py deleted file mode 100644 index 9c2c7769c9..0000000000 --- a/examples/run_summarization_finetuning.py +++ /dev/null @@ -1,502 +0,0 @@ -# coding=utf-8 -# Copyright 2019 The HuggingFace Inc. team. -# Copyright (c) 2019 The HuggingFace Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Finetuning seq2seq models for sequence generation.""" - -import argparse -import functools -import logging -import os -import random -import sys - -import numpy as np -from tqdm import tqdm, trange -import torch -from torch.optim import Adam -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler - -from transformers import ( - AutoTokenizer, - BertForMaskedLM, - BertConfig, - PreTrainedEncoderDecoder, - Model2Model, -) - -from utils_summarization import ( - CNNDailyMailDataset, - encode_for_summarization, - fit_to_block_size, - build_lm_labels, - build_mask, - compute_token_type_ids, -) - -logger = logging.getLogger(__name__) -logging.basicConfig(stream=sys.stdout, level=logging.INFO) - - -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - - -# ------------ -# Load dataset -# ------------ - - -def load_and_cache_examples(args, tokenizer): - dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir) - return dataset - - -def collate(data, tokenizer, block_size): - """ List of tuple as an input. """ - # remove the files with empty an story/summary, encode and fit to block - data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data) - data = [ - encode_for_summarization(story, summary, tokenizer) for story, summary in data - ] - data = [ - ( - fit_to_block_size(story, block_size, tokenizer.pad_token_id), - fit_to_block_size(summary, block_size, tokenizer.pad_token_id), - ) - for story, summary in data - ] - - stories = torch.tensor([story for story, summary in data]) - summaries = torch.tensor([summary for story, summary in data]) - encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id) - encoder_mask = build_mask(stories, tokenizer.pad_token_id) - decoder_mask = build_mask(summaries, tokenizer.pad_token_id) - lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id) - - return ( - stories, - summaries, - encoder_token_type_ids, - encoder_mask, - decoder_mask, - lm_labels, - ) - - -# ---------- -# Optimizers -# ---------- - - -class BertSumOptimizer(object): - """ Specific optimizer for BertSum. - - As described in [1], the authors fine-tune BertSum for abstractive - summarization using two Adam Optimizers with different warm-up steps and - learning rate. They also use a custom learning rate scheduler. - - [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." - arXiv preprint arXiv:1908.08345 (2019). - """ - - def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8): - self.encoder = model.encoder - self.decoder = model.decoder - self.lr = lr - self.warmup_steps = warmup_steps - - self.optimizers = { - "encoder": Adam( - model.encoder.parameters(), - lr=lr["encoder"], - betas=(beta_1, beta_2), - eps=eps, - ), - "decoder": Adam( - model.decoder.parameters(), - lr=lr["decoder"], - betas=(beta_1, beta_2), - eps=eps, - ), - } - - self._step = 0 - - def _update_rate(self, stack): - return self.lr[stack] * min( - self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5) - ) - - def zero_grad(self): - self.optimizer_decoder.zero_grad() - self.optimizer_encoder.zero_grad() - - def step(self): - self._step += 1 - for stack, optimizer in self.optimizers.items(): - new_rate = self._update_rate(stack) - for param_group in optimizer.param_groups: - param_group["lr"] = new_rate - optimizer.step() - - -# ------------ -# Train -# ------------ - - -def train(args, model, tokenizer): - """ Fine-tune the pretrained model on the corpus. """ - set_seed(args) - - # Load the data - args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) - train_dataset = load_and_cache_examples(args, tokenizer) - train_sampler = RandomSampler(train_dataset) - model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512) - train_dataloader = DataLoader( - train_dataset, - sampler=train_sampler, - batch_size=args.train_batch_size, - collate_fn=model_collate_fn, - ) - - # Training schedule - if args.max_steps > 0: - t_total = args.max_steps - args.num_train_epochs = t_total // ( - len(train_dataloader) // args.gradient_accumulation_steps + 1 - ) - else: - t_total = ( - len(train_dataloader) - // args.gradient_accumulation_steps - * args.num_train_epochs - ) - - # Prepare the optimizer - lr = {"encoder": 0.002, "decoder": 0.2} - warmup_steps = {"encoder": 20000, "decoder": 10000} - optimizer = BertSumOptimizer(model, lr, warmup_steps) - - # Train - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataset)) - logger.info(" Num Epochs = %d", args.num_train_epochs) - logger.info( - " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size - ) - logger.info( - " Total train batch size (w. parallel, distributed & accumulation) = %d", - args.train_batch_size * args.gradient_accumulation_steps - # * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - model.zero_grad() - train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True) - - global_step = 0 - tr_loss = 0.0 - for _ in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) - for step, batch in enumerate(epoch_iterator): - source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch - - source = source.to(args.device) - target = target.to(args.device) - encoder_token_type_ids = encoder_token_type_ids.to(args.device) - encoder_mask = encoder_mask.to(args.device) - decoder_mask = decoder_mask.to(args.device) - lm_labels = lm_labels.to(args.device) - - model.train() - outputs = model( - source, - target, - encoder_token_type_ids=encoder_token_type_ids, - encoder_attention_mask=encoder_mask, - decoder_attention_mask=decoder_mask, - decoder_lm_labels=lm_labels, - ) - - loss = outputs[0] - print(loss) - if args.gradient_accumulation_steps > 1: - loss /= args.gradient_accumulation_steps - - loss.backward() - - tr_loss += loss.item() - if (step + 1) % args.gradient_accumulation_steps == 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) - optimizer.step() - model.zero_grad() - global_step += 1 - - if args.max_steps > 0 and global_step > args.max_steps: - epoch_iterator.close() - break - - if args.max_steps > 0 and global_step > args.max_steps: - train_iterator.close() - break - - return global_step, tr_loss / global_step - - -# ------------ -# Train -# ------------ - - -def evaluate(args, model, tokenizer, prefix=""): - set_seed(args) - - args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) - eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) - eval_sampler = SequentialSampler(eval_dataset) - eval_dataloader = DataLoader( - eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size - ) - - # multi-gpu evaluate - if args.n_gpu > 1: - model = torch.nn.DataParallel(model) - - logger.info("***** Running evaluation {} *****".format(prefix)) - logger.info(" Num examples = %d", len(eval_dataset)) - logger.info(" Batch size = %d", args.eval_batch_size) - eval_loss = 0.0 - nb_eval_steps = 0 - model.eval() - - for batch in tqdm(eval_dataloader, desc="Evaluating"): - source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch - - source = source.to(args.device) - target = target.to(args.device) - encoder_token_type_ids = encoder_token_type_ids.to(args.device) - encoder_mask = encoder_mask.to(args.device) - decoder_mask = decoder_mask.to(args.device) - lm_labels = lm_labels.to(args.device) - - with torch.no_grad(): - outputs = model( - source, - target, - encoder_token_type_ids=encoder_token_type_ids, - encoder_attention_mask=encoder_mask, - decoder_attention_mask=decoder_mask, - decoder_lm_labels=lm_labels, - ) - lm_loss = outputs[0] - eval_loss += lm_loss.mean().item() - nb_eval_steps += 1 - - eval_loss = eval_loss / nb_eval_steps - perplexity = torch.exp(torch.tensor(eval_loss)) - - result = {"perplexity": perplexity} - - # Save the evaluation's results - output_eval_file = os.path.join(args.output_dir, "eval_results.txt") - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results {} *****".format(prefix)) - for key in sorted(result.keys()): - logger.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) - - return result - - -def save_model_checkpoints(args, model, tokenizer): - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - logger.info("Saving model checkpoint to %s", args.output_dir) - - # Save a trained model, configuration and tokenizer using `save_pretrained()`. - # They can then be reloaded using `from_pretrained()` - model_to_save = ( - model.module if hasattr(model, "module") else model - ) # Take care of distributed/parallel training - model_to_save.save_pretrained(args.output_dir, model_type='bert') - tokenizer.save_pretrained(args.output_dir) - torch.save(args, os.path.join(args.output_dir, "training_arguments.bin")) - - -def main(): - parser = argparse.ArgumentParser() - - # Required parameters - parser.add_argument( - "--data_dir", - default=None, - type=str, - required=True, - help="The input training data file (a text file).", - ) - parser.add_argument( - "--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.", - ) - - # Optional parameters - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--do_evaluate", - type=bool, - default=False, - help="Run model evaluation on out-of-sample data.", - ) - parser.add_argument("--do_train", type=bool, default=False, help="Run training.") - parser.add_argument( - "--do_overwrite_output_dir", - type=bool, - default=False, - help="Whether to overwrite the output dir.", - ) - parser.add_argument( - "--model_name_or_path", - default="bert-base-cased", - type=str, - help="The model checkpoint to initialize the encoder and decoder's weights with.", - ) - parser.add_argument( - "--model_type", - default="bert", - type=str, - help="The decoder architecture to be fine-tuned.", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--max_steps", - default=-1, - type=int, - help="If > 0: set total number of training steps to perform. Override num_train_epochs.", - ) - parser.add_argument( - "--to_cpu", default=False, type=bool, help="Whether to force training on CPU." - ) - parser.add_argument( - "--num_train_epochs", - default=10, - type=int, - help="Total number of training epochs to perform.", - ) - parser.add_argument( - "--per_gpu_train_batch_size", - default=4, - type=int, - help="Batch size per GPU/CPU for training.", - ) - parser.add_argument("--seed", default=42, type=int) - args = parser.parse_args() - - if ( - os.path.exists(args.output_dir) - and os.listdir(args.output_dir) - and args.do_train - and not args.do_overwrite_output_dir - ): - raise ValueError( - "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format( - args.output_dir - ) - ) - - # Set up training device - if args.to_cpu or not torch.cuda.is_available(): - args.device = torch.device("cpu") - args.n_gpu = 0 - else: - args.device = torch.device("cuda") - args.n_gpu = torch.cuda.device_count() - - # Load pretrained model and tokenizer. The decoder's weights are randomly initialized. - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) - config = BertConfig.from_pretrained(args.model_name_or_path) - decoder_model = BertForMaskedLM(config) - model = Model2Model.from_pretrained( - args.model_name_or_path, decoder_model=decoder_model - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - 0, - args.device, - args.n_gpu, - False, - False, - ) - - logger.info("Training/evaluation parameters %s", args) - - # Train the model - model.to(args.device) - if args.do_train: - try: - global_step, tr_loss = train(args, model, tokenizer) - except KeyboardInterrupt: - response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]") - if response.lower() in ["", "y", "yes"]: - save_model_checkpoints(args, model, tokenizer) - sys.exit(0) - - logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) - save_model_checkpoints(args, model, tokenizer) - - # Evaluate the model - results = {} - if args.do_evaluate: - checkpoints = [args.output_dir] - logger.info("Evaluate the following checkpoints: %s", checkpoints) - for checkpoint in checkpoints: - encoder_checkpoint = os.path.join(checkpoint, "bert_encoder") - decoder_checkpoint = os.path.join(checkpoint, "bert_decoder") - model = PreTrainedEncoderDecoder.from_pretrained( - encoder_checkpoint, decoder_checkpoint - ) - model.to(args.device) - print("model loaded") - - return results - - -if __name__ == "__main__": - main() diff --git a/examples/utils_summarization.py b/examples/utils_summarization.py index 087c88bd4e..7cbd4cd61b 100644 --- a/examples/utils_summarization.py +++ b/examples/utils_summarization.py @@ -25,9 +25,8 @@ class CNNDailyMailDataset(Dataset): [2] https://github.com/abisee/cnn-dailymail/ """ - def __init__(self, tokenizer, prefix="train", data_dir=""): + def __init__(self, data_dir="", prefix="train"): assert os.path.isdir(data_dir) - self.tokenizer = tokenizer # We initialize the class by listing all the files that contain # stories and summaries. Files are not read in memory given @@ -104,31 +103,30 @@ def _add_missing_period(line): # -------------------------- -def fit_to_block_size(sequence, block_size, pad_token): +def fit_to_block_size(sequence, block_size, pad_token_id): """ Adapt the source and target sequences' lengths to the block size. - If the sequence is shorter than the block size we pad it with -1 ids - which correspond to padding tokens. + If the sequence is shorter we append padding token to the right of the sequence. """ if len(sequence) > block_size: return sequence[:block_size] else: - sequence.extend([pad_token] * (block_size - len(sequence))) + sequence.extend([pad_token_id] * (block_size - len(sequence))) return sequence -def build_lm_labels(sequence, pad_token): - """ Padding token, encoded as 0, are represented by the value -1 so they +def build_lm_labels(sequence, pad_token_id): + """ Padding token are replaced by the value -1 so they are not taken into account in the loss computation. """ padded = sequence.clone() - padded[padded == pad_token] = -1 + padded[padded == pad_token_id] = -1 return padded -def build_mask(sequence, pad_token): +def build_mask(sequence, pad_token_id): """ Builds the mask. The attention mechanism will only attend to positions with value 1. """ mask = torch.ones_like(sequence) - idx_pad_tokens = sequence == pad_token + idx_pad_tokens = sequence == pad_token_id mask[idx_pad_tokens] = 0 return mask diff --git a/transformers/generate/__init__.py b/transformers/generate/__init__.py new file mode 100644 index 0000000000..21ac612155 --- /dev/null +++ b/transformers/generate/__init__.py @@ -0,0 +1 @@ +from .beam_search import BeamSearch diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py new file mode 100644 index 0000000000..09e340a150 --- /dev/null +++ b/transformers/generate/beam_search.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# MIT License + +# Copyright (c) 2017-Present OpenNMT + +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Use Beam Search to generate sequences using encoder-decoder models. +""" +import torch +from torch import nn + + +class BeamSearch(nn.Module): + def __init__( + self, + model, + tokenizer, + beam_size, + min_length, + max_length, + batch_size=1, + alpha=0, + block_repeating_trigrams=True, + ): + r""" + Inputs: + **model**: instance of ``transformers.PreTrainedEncoderDecoder`` + The pretrained encoder-decoder model that will be used to generate the sequences. + **tokenizer**: instance of ``transformers.PreTrainedTokenizer`` + The pretrained tokenizer associated to the model used in the encoder-decoder. We only + support encoder-decoder that use the same tokenizer for encoder and decoder. The tokenizer + needs to be initialized or this function will raise and exception. + **batch_size**: (`optional`) int + Batch size of the inputs. The value is set automatically when calling `forward`. + **beam_size**: int + Number of beams that are used for each element on the batch. + **min_length**: int + Minimum number of steps performed by the beam search before terminating. + **max_length**: int + Maximum number of steps performed by the beam search. Any beam that has not finished + will return its current solution with the highest probability. The sequence that is + returned has a length of max_length-1 to account for the end token that is subsequently added. + **alpha**: float + Parameter of the length penalty. Read the documentation of the `_length_penalty` method for mode details. + **block_repeating_trigrams**: bool + Whether to block sequences that have repeating 3-grams. + """ + super(BeamSearch, self).__init__() + self.model = model + self.tokenizer = tokenizer + + self.bos_token_id = tokenizer.bos_token_id + self.eos_token_id = tokenizer.eos_token_id + self.pad_token_id = tokenizer.pad_token_id + + self.batch_size = batch_size + self.beam_size = beam_size + self.min_length = min_length + self.max_length = max_length + + self.block_repeating_trigram = block_repeating_trigrams + self.apply_length_penalty = False if alpha == 0 else True + self.alpha = alpha + + self._init_beam_state(batch_size) + + def __len__(self): + try: + return self.growing_beams.size(1) + except NameError: + return 0 + + def _init_beam_state(self, batch_size): + """ (re-)Initialize the state of the beams. """ + self.hypotheses = [[] for _ in range(batch_size)] + self.batch_offset = torch.arange(batch_size, dtype=torch.long) + self.beam_offset = torch.arange( + 0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long + ) + self.growing_beams = torch.full( + (batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long + ) + self.topk_log_probabilities = torch.tensor( + [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float + ).repeat(batch_size) + self.results = { + "predictions": [[] for _ in range(batch_size)], + "scores": [[] for _ in range(batch_size)], + } + self._step = 0 + self.is_done = False + + def forward(self, encoder_input_ids, **model_kwargs): + """ Generate a sequence using Beam Search. """ + # keyword arguments come in 3 flavors: encoder-specific (prefixed by + # `encoder_`), decoder-specific (prefixed by `decoder_`) and those + # that apply to the model as whole. + # We let the specific kwargs override the common ones in case of conflict. + kwargs_common = { + argument: value + for argument, value in model_kwargs.items() + if not argument.startswith("encoder_") and not argument.startswith("decoder_") + } + kwargs_decoder = kwargs_common.copy() + kwargs_encoder = kwargs_common.copy() + kwargs_encoder.update( + { + argument[len("encoder_") :]: value + for argument, value in model_kwargs.items() + if argument.startswith("encoder_") + } + ) + kwargs_decoder.update( + { + argument[len("decoder_") :]: value + for argument, value in model_kwargs.items() + if argument.startswith("decoder_") + } + ) + + # forward pass on the encoder + encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder) + kwargs_decoder["encoder_hidden_states"] = tile( + encoder_outputs, self.beam_size, dim=0 + ) + + # grow the beam by generating sequences in an autoregressive way + batch_size = encoder_input_ids.size(0) + self._init_beam_state(batch_size) + for step in range(self.max_length): + # prepare the decoder input + decoder_input = fit_to_block_size( + self.growing_beams, self.tokenizer.pad_token_id + ) + kwargs_decoder["decoder_lm_labels"] = build_lm_labels( + decoder_input, self.tokenizer.pad_token_id + ) + kwargs_decoder["decoder_attention_mask"] = build_mask( + decoder_input, self.tokenizer.pad_token_id + ) + + outputs = self.model.decoder(decoder_input, kwargs_decoder) + log_probabilities = torch.nn.functional.log_softmax(outputs[1]) + surviving_beams_rows = self.grow(log_probabilities) + if self.is_done: + break + + kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ + "encoder_hidden_states" + ].index_select(0, surviving_beams_rows) + kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[ + "encoder_attention_mask" + ].index_select(0, surviving_beams_rows) + + return self.results + + def grow(self, log_probabilities): + """ Grow the beams by one step. """ + self._step += 1 + + # The number of beams changes as some beams finish so we define _B + vocab_size = log_probabilities.size(-1) + _B = log_probabilities.size(0) // self.beam_size + + # Multiply each beam probability with the probability of the + # next token (conditioned on the words in the beam). + log_probabilities += self.topk_log_probabilities.view(-1, 1) + + self._enforce_min_length(log_probabilities) + if self.block_repeating_trigram: + self._remove_beams_with_repeating_trigrams(log_probabilities, _B) + + # Find the `beam_size` (previous_beam + token) combinations with + # the highest score + topk_log_probabilities, topk_ids = torch.topk( + log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1 + ) + + # Apply the length penalty. The +1 accounts for the [EOS] token + # that will be added if the beam ends. + topk_scores = topk_log_probabilities + if self.apply_length_penalty: + topk_scores /= self._length_penalty() + + # Retrieve the corresponding respective beam and token id + # topk_token_ids[i] will be added to topk_beam_ids[i] + topk_beam_ids = topk_ids.div(vocab_size) + topk_token_ids = topk_ids.fmod(vocab_size) + + # Retrieve the row index of the surviving beams in the original + # view of the log_probabilities tensor + surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1) + surviving_beams_rows = surviving_beams_per_batch.view(-1) + + # Append the last predictions + self.growing_beams = torch.cat( + [ + self.growing_beams.index_select(0, surviving_beams_rows), + topk_token_ids.view(-1, 1), + ], + 1, + ) + + # Check if any of the beam searches has ended during this + # growth step. Also if top beam (most probable) has ended + # for one element of the batch. + is_finished = topk_token_ids.eq(self.eos_token_id) + self._enforce_max_length(is_finished) + if is_finished.any(): + non_finished = self._cut_finished(is_finished, topk_scores) + self.batch_offset = self.batch_offset.index_select(0, non_finished) + surviving_beams_per_batch = surviving_beams_per_batch.index_select( + 0, non_finished + ) + self.topk_log_probabilities = self.topk_log_probabilities.index_select( + 0, non_finished + ) + + surviving_beams_rows = surviving_beams_per_batch.view(-1) + self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows) + + return surviving_beams_rows + + def _cut_finished(self, is_finished, topk_scores): + """ Save the finished searches and cut the correponding sequences off + the beams. """ + is_top_beam_finished = is_finished[:, 0].eq(True) + + # Save the finished searches + predictions = self.growing_beams.view( + -1, self.beam_size, self.growing_beams.size(1) + ) + for i in range(is_finished.size(0)): + if is_top_beam_finished[i]: + is_finished[i].fill_(1) + finished_hyp = is_finished[i].nonzero().view(-1) + + # Store the finished beams as a (score, prediction) hypothesis. + b = self.batch_offset[i] + for j in finished_hyp: + self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :])) + + # If the batch reached the end, save the best hypotheses + # in terms of length-penalized score. + if is_top_beam_finished[i]: + best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0]) + self.results["scores"][b].append(best_score) + self.results["predictions"][b].append(best_prediction) + + non_finished = is_top_beam_finished.eq(False).nonzero().view(-1) + if len(non_finished) == 0: + self.is_done = True + + return non_finished + + def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B): + if self._step + 1 > 3: # [BOS] does not count + for i in range(_B * self.beam_size): + tokens = self.growing_beams[i] + trigrams = [ + (tokens[j - 1], tokens[j], tokens[j + 1]) + for j in range(1, len(self) - 1) + ] + last_trigram = tuple(trigrams[-1]) + if last_trigram in trigrams[:-1]: + log_probabilities[i] = -1e20 + + def _enforce_min_length(self, log_probabilities): + if self._step < self.min_length: + log_probabilities[:, self.eos_token_id] = -1e20 + + def _enforce_max_length(self, is_finished): + # +1 because we will need to add an [EOS] token + if self._step + 1 == self.max_length: + is_finished.fill_(1) + + def _length_penalty(self): + """ The calculation of the length penalty follows that of [1]. + + [1] Wu, Yonghui, et al. "Google's neural machine translation system: + Bridging the gap between human and machine translation." arXiv preprint + arXiv:1609.08144 (2016). + """ + return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha + + +def tile(x, count, dim=0): + """ + Tiles `x` along dimension `dim` `count` times. + + Example: + >> ex = torch.tensor([1,2],[3,4]) + >> tile(ex, 2, 0) + torch.Tensor([[1,2],[1,2],[3,4],[3,4]]) + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = ( + x.view(batch, -1) + .transpose(0, 1) + .repeat(count, 1) + .transpose(0, 1) + .contiguous() + .view(*out_size) + ) + if dim != 0: + x = x.permute(perm).contiguous() + return x + + +def fit_to_block_size(sequence, block_size, pad_token_id): + """ Adapt the source and target sequences' lengths to the block size. + If the sequence is shorter we append padding tokens to the right. + """ + if len(sequence) > block_size: + return sequence[:block_size] + else: + sequence.extend([pad_token_id] * (block_size - len(sequence))) + return sequence + + +def build_lm_labels(sequence, pad_token_id): + """ Padding token, encoded as 0, are represented by the value -1 so they + are not taken into account in the loss computation. """ + padded = sequence.clone() + padded[padded == pad_token_id] = -1 + return padded + + +def build_mask(sequence, pad_token_id): + """ Builds the mask. The attention mechanism will only attend to positions + with value 1. """ + mask = torch.ones_like(sequence) + idx_pad_tokens = sequence == pad_token_id + mask[idx_pad_tokens] = 0 + return mask diff --git a/transformers/modeling_beam_search.py b/transformers/modeling_beam_search.py deleted file mode 100644 index 171dcb7247..0000000000 --- a/transformers/modeling_beam_search.py +++ /dev/null @@ -1,271 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2019 Yang Liu - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -A general wrapper around models with LM heads to generate sequences -using beam search. -""" -import torch -from torch import nn - - -class TransformerBeamSearch(nn.Module): - def __init__( - self, - model, - tokenizer, - batch_size, - beam_size, - min_length, - max_length, - alpha=0, - block_repeating_trigram=True, - ): - """ - Attributes: - mask_word_id: token id that corresponds to the mask - """ - super(TransformerBeamSearch, self).__init__() - self.model = model - self.tokenizer = tokenizer - - self.start_token_id = tokenizer.start_token_id - self.end_token_id = tokenizer.end_token_id - self.pad_token_id = tokenizer.pad_token_id - - self.beam_size = beam_size - self.min_length = min_length - self.max_length = max_length - - self.block_repeating_trigram = block_repeating_trigram - self.apply_length_penalty = False if alpha == 0 else True - self.alpha = alpha - - # State of the beam - self.hypotheses = [[] for _ in range(batch_size)] - self.batch_offset = torch.arange(batch_size, dtype=torch.long) - self.beam_offset = torch.arange( - 0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long - ) - self.growing_beam = torch.full( - (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long - ) - self.topk_log_probabilities = torch.tensor( - [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float - ).repeat(batch_size) - self.results = { - "prediction": [[] for _ in batch_size], - "scores": [[] for _ in batch_size], - } - self._step = 0 - self.is_done = False - - def step(self, log_probabilities): - """ Grows the beam by one step. """ - self._step += 1 - - # The batch size changes as some beams finish so we define _B - vocab_size = log_probabilities.size(-1) - _B = log_probabilities.size(0) // self.beam_size - - # Multiply each beam probability with the probability of the - # next token (conditioned on the words in the beam). - log_probabilities += self.topk_log_probabilities.view(-1, 1) - - self.enforce_min_length(log_probabilities) - if self.block_repeating_trigram: - self.remove_repeating_trigrams(log_probabilities, _B) - - # Find the `beam_size` (previous_beam + token) combinations with - # the highest score - topk_log_probabilities, topk_ids = log_probabilities.topk( - log_probabilities.view(_B, self.beam_size * vocab_size), - self.beam_size, - dim=1, - ) - - # Apply the length penalty. The +1 accounts for the [EOS] token - # that will be added if the beam ends. - topk_scores = topk_log_probabilities / self.length_penalty() - - # Retrieve the corresponding respective beam and token id - # topk_token_ids[i] will be added to topk_beam_ids[i] - topk_beam_ids = topk_ids.div(vocab_size) - topk_token_ids = topk_ids.fmod(vocab_size) - - # Retrieve the row index of the surviving beams in the original - # view of the log_probabilities tensor - surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view( - -1 - ) - - # Append the last predictions - self.growing_beam = torch.cat( - [ - self.growing_beam.index_select(0, surviving_beams_rows), - topk_token_ids.view(-1, 1), - ], - 1, - ) - - # Check if any of the beam searches has ended during this - # growth step. Also if top beam (most probable) has ended - # for one element of the batch. - is_finished = topk_token_ids.eq(self.end_token_id) - self.enforce_max_length() - is_top_beam_finished = is_finished[:, 0].eq(1) - - # Save the finished searches - if is_finished.any(): - predictions = self.growing_beam.view( - -1, self.beam_size, self.growing_beam.size(1) - ) - for i in range(is_finished.size(0)): - if is_top_beam_finished[i]: - is_finished[i].fill_(1) - finished_hyp = is_finished[i].nonzero().view(-1) - - # Store finished hypotheses for this batch. - b = self.batch_offset[i] - for j in finished_hyp: - self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :])) - - # If the batch reached the end, save the best hypotheses - # in terms of length-penalized score. - if is_top_beam_finished[i]: - best_hyp = sorted( - self.hypotheses[b], key=lambda x: x[0], reverse=True - ) - best_score, best_prediction = best_hyp[0] - self.results["scores"][b].append(best_score) - self.results["predictions"][b].append(best_prediction) - - non_finished = is_top_beam_finished.eq(0).nonzero().view(-1) - if len(non_finished) == 0: - self.is_done = True - - # Remove finished batches for the next step. - topk_log_probabilities = topk_log_probabilities.index_select( - 0, non_finished - ) - self.batch_offset = self.batch_offset.index_select(0, non_finished) - self.growing_beam = predictions.index_select(0, non_finished).view( - -1, self.growing_beam.size(-1) - ) - - surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished) - - return surviving_beams_rows - - def forward(self, encoder_input_ids, **kwargs): - # keyword arguments come in 3 flavors: encoder-specific (prefixed by - # `encoder_`), decoder-specific (prefixed by `decoder_`) and those - # that apply to the model as whole. - # We let the specific kwargs override the common ones in case of conflict. - kwargs_encoder = { - argument[len("encoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("encoder_") - } - kwargs_decoder = { - argument[len("decoder_"):]: value - for argument, value in kwargs.items() - if argument.startswith("decoder_") - } - kwargs_common = { - argument: value - for argument, value in kwargs.items() - if not (argument.startswith("encoder_") or argument.startswith("decoder_")) - } - kwargs_decoder = dict(kwargs_common, **kwargs_decoder) - kwargs_encoder = dict(kwargs_common, **kwargs_encoder) - - # forward pass on the encoder - encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder) - kwargs_decoder["encoder_hidden_states"] = tile( - encoder_outputs, self.beam_size, dim=0 - ) - - # grow the beam by generating sequences in an autoregressive way - self.growing_beam = torch.full( - (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long - ) - for step in range(self.max_length): - decoder_input = self.growing_beam[:, -1] - outputs = self.model.decoder(decoder_input, kwargs_decoder) - log_probabilities = torch.nn.functional.log_softmax(outputs[1]) - surviving_beams_rows = self.step(log_probabilities) - if self.is_done: - break - - kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ - "encoder_hidden_states" - ].index_select(0, surviving_beams_rows) - - return self.results - - def remove_repeating_trigrams(self, log_probabilities, _B): - if(self._step + 1 > 3): - for i in range(_B * self.beam_size): - tokens = [t for t in self.growing_beam[i]] - trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)] - last_trigram = tuple(trigrams[-1]) - if last_trigram in trigrams[:-1]: - log_probabilities[i] = -1e20 - - def enforce_min_length(self): - if self._step < self.min_length: - self.log_probabilities[self.end_token_id] = -1e20 - - def enforce_max_length(self): - if self._step + 1 == self.max_length: - self.is_finished.fill_(1) - - def length_penalty(self): - return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha - - -def tile(x, count, dim=0): - """ - Tiles `x` along dimension `dim` `count` times. - - Example: - >> ex = torch.tensor([1,2],[3,4]) - >> tile(ex, 2, 0) - torch.Tensor([[1,2],[1,2],[3,4],[3,4]]) - """ - perm = list(range(len(x.size()))) - if dim != 0: - perm[0], perm[dim] = perm[dim], perm[0] - x = x.permute(perm).contiguous() - out_size = list(x.size()) - out_size[0] *= count - batch = x.size(0) - x = ( - x.view(batch, -1) - .transpose(0, 1) - .repeat(count, 1) - .transpose(0, 1) - .contiguous() - .view(*out_size) - ) - if dim != 0: - x = x.permute(perm).contiguous() - return x diff --git a/transformers/tests/beam_search_tests.py b/transformers/tests/beam_search_tests.py new file mode 100644 index 0000000000..a92ebf3578 --- /dev/null +++ b/transformers/tests/beam_search_tests.py @@ -0,0 +1,226 @@ +from collections import namedtuple +import unittest + +import numpy as np +import torch + +from transformers.generate import BeamSearch +from transformers import PreTrainedEncoderDecoder + + +StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"]) +StubTransformer = namedtuple("Transformer", ["encoder", "decoder"]) + + +class BeamSearchtest(unittest.TestCase): + def test_beam_search_encoder_decoder_integration(self): + """ We make sure that no internal change in the PreTrainedEncoderDecoder + class will break the integration with the beam search. + """ + + model = PreTrainedEncoderDecoder("encoder", "decoder") + tokenizer = StubTokenizer(0, 1, 2) + try: + _ = BeamSearch( + model=model, + tokenizer=tokenizer, + batch_size=1, + beam_size=1, + min_length=1, + max_length=1, + alpha=0, + block_repeating_trigrams=False, + ) + except: + self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.") + + def test_beam_search_min_length(self): + """ We keep predicting the end_token for the first beam and check that + it is not marked as finished until the beam has reached the minimum + length. """ + eos_idx = 3 + vocab_size = 10 + + batch_size = 3 + beam_size = 2 + min_length = 5 + + beam = BeamSearch( + model=StubTransformer("encoder", "decoder"), + tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2), + batch_size=batch_size, + beam_size=beam_size, + min_length=5, + max_length=10, + alpha=0, + block_repeating_trigrams=False, + ) + + # To test that the minimum length is correctly enforced we constantly + # assign the highest probability to the [EOS] token (and assign lower + # probabilities to some other tokens). + # Since BeamSearch will reset its probability to 1e-20 as long as + # min_length has not been reached, we need to reset the value between + # steps. + non_eos_idxs = [4, 5, 1, 8, 9] + score_distribution = torch.log_softmax( + torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0 + ) + + log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf")) + log_probabilities[0, eos_idx] = score_distribution[0] + for idx, score in zip(non_eos_idxs, score_distribution[1:]): + log_probabilities[0, idx] = score + + for step in range(1, min_length + 2): + log_probabilities[0, eos_idx] = score_distribution[0] + + # Beam #3 and #4 teminate at the first step since the probability + # of the [EOS] token is -1e20 > -\infty so there are only two beams left. + surviving_beams_rows = beam.grow(log_probabilities) + if step < min_length: + np.testing.assert_array_equal( + beam.growing_beams.numpy(), + np.repeat(np.array([[0] + [4] * step]), 2, axis=0), + ) + elif step == min_length: + np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([])) + self.assertTrue(beam.is_done) + break + + log_probabilities = log_probabilities.index_select(0, surviving_beams_rows) + + def test_beam_search_max_length(self): + """ We keep predicting the same non-EOS token until we reach the + maximum permitted length """ + batch_size = 3 + beam_size = 2 + max_length = 5 + vocab_size = 10 + + beam = BeamSearch( + model=StubTransformer("encoder", "decoder"), + tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + batch_size=batch_size, + beam_size=beam_size, + min_length=2, + max_length=max_length, + alpha=0, + block_repeating_trigrams=False, + ) + + log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf")) + + # To test that beam search enforces the max length constraint we + # keep giving the highest probability to a token that is not the + # [EOS] token. + # The beam search will stop at max_length-1, assuming that one would + # add the [EOS] token at the end of the returned sequence. + token_idxs = [3, 4, 5] + score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0) + for idx, score in zip(token_idxs, score_distribution): + log_probabilities[:, idx] = score + + for step in range(1, max_length + 2): + surviving_beams_rows = beam.grow(log_probabilities) + if step + 1 < max_length: + self.assertFalse(beam.is_done) + elif step + 1 == max_length: # Now [EOS] is the most probable token + np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([])) + self.assertTrue(beam.is_done) + break + + log_probabilities = log_probabilities.index_select(0, surviving_beams_rows) + + def test_beam_search_block_repeating_trigrams(self): + """ We make sure that the beams that contain repeating trigrams are removed. """ + batch_size = 3 + beam_size = 2 + max_length = 10 + vocab_size = 10 + + beam = BeamSearch( + model=StubTransformer("encoder", "decoder"), + tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + batch_size=batch_size, + beam_size=beam_size, + min_length=2, + max_length=max_length, + alpha=0, + block_repeating_trigrams=True, + ) + + log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf")) + + # To test that BeamSearch enforces the 3-gram constraint we give the + # highest probably to the same tokens in a cyclic fashion and make sure + # they disappear once the cycle has completed. + token_idxs = [3, 4, 5] + score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0) + for idx, score in zip(token_idxs, score_distribution): + log_probabilities[:, idx] = score + + for step in range(1, max_length + 2): + # Rotate the probabilities at each step + for idx in token_idxs: + score = score_distribution[(idx + step) % 3] + log_probabilities[::beam_size, idx] = score + + surviving_beams_rows = beam.grow(log_probabilities) + log_probabilities = log_probabilities.index_select(0, surviving_beams_rows) + + if step < 7: + self.assertFalse( + np.array_equal( + log_probabilities.numpy()[0, :], + np.array([-1e20] * vocab_size, dtype="float32"), + ) + ) + if step == 7: + np.testing.assert_array_equal( + log_probabilities.numpy()[0, :], + np.array([-1e20] * vocab_size, dtype="float32"), + ) + + def test_beam_search_example_for_one_step(self): + """ We test that the predictions for one step of growth are correct. """ + batch_size = 2 + beam_size = 2 + max_length = 10 + vocab_size = 5 + + beam = BeamSearch( + model=StubTransformer("encoder", "decoder"), + tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + batch_size=batch_size, + beam_size=beam_size, + min_length=2, + max_length=max_length, + alpha=0, + block_repeating_trigrams=False, + ) + + log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf")) + log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0) + log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0) + + # First pass + surviving_beams_rows = beam.grow(log_probabilities) + np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2])) + np.testing.assert_array_equal( + beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]]) + ) + self.assertFalse(beam.is_done) + + # Second pass + surviving_beams_rows = beam.grow(log_probabilities) + np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2])) + np.testing.assert_array_equal( + beam.growing_beams.numpy(), + np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]), + ) + self.assertFalse(beam.is_done) + + +if __name__ == "__name__": + unittest.main()