From ba089c780b918414bd8b669e1764fed728753edf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 6 Nov 2019 13:55:24 +0100 Subject: [PATCH] share pretrained embeddings --- examples/utils_summarization.py | 11 +--- requirements.txt | 4 +- transformers/generate/beam_search.py | 87 ++++++++++++++++++---------- 3 files changed, 60 insertions(+), 42 deletions(-) diff --git a/examples/utils_summarization.py b/examples/utils_summarization.py index 7cbd4cd61b..8e95a04e19 100644 --- a/examples/utils_summarization.py +++ b/examples/utils_summarization.py @@ -136,18 +136,11 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer): as specified in [1] by using `[SEP] [CLS]` tokens to separate sentences. """ - story_lines_token_ids = [ - tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line)) - for line in story_lines - ] - summary_lines_token_ids = [ - tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line)) - for line in summary_lines - ] - + story_lines_token_ids = [tokenizer.encode(line) for line in story_lines] story_token_ids = [ token for sentence in story_lines_token_ids for token in sentence ] + summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines] summary_token_ids = [ token for sentence in summary_lines_token_ids for token in sentence ] diff --git a/requirements.txt b/requirements.txt index 9c43abc6d7..060aba915d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ regex # For XLNet sentencepiece # For XLM -sacremoses \ No newline at end of file +sacremoses +# For ROUGE +pyrouge diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py index 09e340a150..e1b2d23da0 100644 --- a/transformers/generate/beam_search.py +++ b/transformers/generate/beam_search.py @@ -26,27 +26,31 @@ Use Beam Search to generate sequences using encoder-decoder models. import torch from torch import nn +import logging + + +logger = logging.getLogger(__name__) + class BeamSearch(nn.Module): def __init__( self, model, - tokenizer, + bos_token_id, + pad_token_id, + eos_token_id, + batch_size, beam_size, min_length, max_length, - batch_size=1, alpha=0, block_repeating_trigrams=True, + device=torch.device("cpu"), ): r""" Inputs: **model**: instance of ``transformers.PreTrainedEncoderDecoder`` The pretrained encoder-decoder model that will be used to generate the sequences. - **tokenizer**: instance of ``transformers.PreTrainedTokenizer`` - The pretrained tokenizer associated to the model used in the encoder-decoder. We only - support encoder-decoder that use the same tokenizer for encoder and decoder. The tokenizer - needs to be initialized or this function will raise and exception. **batch_size**: (`optional`) int Batch size of the inputs. The value is set automatically when calling `forward`. **beam_size**: int @@ -64,11 +68,11 @@ class BeamSearch(nn.Module): """ super(BeamSearch, self).__init__() self.model = model - self.tokenizer = tokenizer + self.device = device - self.bos_token_id = tokenizer.bos_token_id - self.eos_token_id = tokenizer.eos_token_id - self.pad_token_id = tokenizer.pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id self.batch_size = batch_size self.beam_size = beam_size @@ -90,15 +94,24 @@ class BeamSearch(nn.Module): def _init_beam_state(self, batch_size): """ (re-)Initialize the state of the beams. """ self.hypotheses = [[] for _ in range(batch_size)] - self.batch_offset = torch.arange(batch_size, dtype=torch.long) + self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device) self.beam_offset = torch.arange( - 0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long + 0, + batch_size * self.beam_size, + step=self.beam_size, + dtype=torch.long, + device=self.device, ) self.growing_beams = torch.full( - (batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long + (batch_size * self.beam_size, 1), + self.bos_token_id, + dtype=torch.long, + device=self.device, ) self.topk_log_probabilities = torch.tensor( - [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float + [0.0] + [float("-inf")] * (self.beam_size - 1), + dtype=torch.float, + device=self.device, ).repeat(batch_size) self.results = { "predictions": [[] for _ in range(batch_size)], @@ -136,28 +149,37 @@ class BeamSearch(nn.Module): ) # forward pass on the encoder - encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder) + encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder) + encoder_hidden_states = encoder_outputs[0] kwargs_decoder["encoder_hidden_states"] = tile( - encoder_outputs, self.beam_size, dim=0 + encoder_hidden_states, self.beam_size, dim=0 + ) + kwargs_decoder["encoder_attention_mask"] = tile( + kwargs_encoder["attention_mask"], self.beam_size, dim=0 ) # grow the beam by generating sequences in an autoregressive way - batch_size = encoder_input_ids.size(0) + batch_size, block_size = encoder_input_ids.size() self._init_beam_state(batch_size) for step in range(self.max_length): - # prepare the decoder input - decoder_input = fit_to_block_size( - self.growing_beams, self.tokenizer.pad_token_id - ) - kwargs_decoder["decoder_lm_labels"] = build_lm_labels( - decoder_input, self.tokenizer.pad_token_id - ) - kwargs_decoder["decoder_attention_mask"] = build_mask( - decoder_input, self.tokenizer.pad_token_id + # Add padding tokens + decoder_input = torch.full( + (self.growing_beams.size(0), block_size), + self.pad_token_id, + dtype=torch.long, + device=self.growing_beams.device, ) + decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams - outputs = self.model.decoder(decoder_input, kwargs_decoder) - log_probabilities = torch.nn.functional.log_softmax(outputs[1]) + # compute decoder_attention_mask + decoder_mask = torch.ones_like(decoder_input) + idx_pad_tokens = decoder_input == self.pad_token_id + decoder_mask[idx_pad_tokens] = 0 + kwargs_decoder["attention_mask"] = decoder_mask + + outputs = self.model.decoder(decoder_input, **kwargs_decoder) + last_token_scores = outputs[0][:, -1, :].squeeze(1) + log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0) surviving_beams_rows = self.grow(log_probabilities) if self.is_done: break @@ -189,13 +211,13 @@ class BeamSearch(nn.Module): # Find the `beam_size` (previous_beam + token) combinations with # the highest score - topk_log_probabilities, topk_ids = torch.topk( + self.topk_log_probabilities, topk_ids = torch.topk( log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1 ) # Apply the length penalty. The +1 accounts for the [EOS] token # that will be added if the beam ends. - topk_scores = topk_log_probabilities + topk_scores = self.topk_log_probabilities if self.apply_length_penalty: topk_scores /= self._length_penalty() @@ -337,8 +359,9 @@ def fit_to_block_size(sequence, block_size, pad_token_id): if len(sequence) > block_size: return sequence[:block_size] else: - sequence.extend([pad_token_id] * (block_size - len(sequence))) - return sequence + return torch.cat( + (sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0 + ) def build_lm_labels(sequence, pad_token_id):