From 4735c2af0715c24d47b34c167fb7d5543493b87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 8 Nov 2019 11:16:26 +0100 Subject: [PATCH] tweaks to the BeamSearch API --- transformers/generate/beam_search.py | 63 ++++++++++--------------- transformers/tests/beam_search_tests.py | 53 ++++++++++++++------- 2 files changed, 59 insertions(+), 57 deletions(-) diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py index e1b2d23da0..a18d20f31a 100644 --- a/transformers/generate/beam_search.py +++ b/transformers/generate/beam_search.py @@ -32,7 +32,7 @@ import logging logger = logging.getLogger(__name__) -class BeamSearch(nn.Module): +class BeamSearch(object): def __init__( self, model, @@ -45,12 +45,17 @@ class BeamSearch(nn.Module): max_length, alpha=0, block_repeating_trigrams=True, - device=torch.device("cpu"), ): r""" Inputs: **model**: instance of ``transformers.PreTrainedEncoderDecoder`` The pretrained encoder-decoder model that will be used to generate the sequences. + **bos_token_id**: int + Id that is used by the tokenizer to represent the beginning of a sentence. + **pad_token_id**: int + Id that is used by the tokenizer for padding. + **eos_token_id**: int + Id that is used by the tokenizer to represent the end of a sentence. **batch_size**: (`optional`) int Batch size of the inputs. The value is set automatically when calling `forward`. 
**beam_size**: int @@ -68,7 +73,7 @@ class BeamSearch(nn.Module): """ super(BeamSearch, self).__init__() self.model = model - self.device = device + self.device = next(model.parameters()).device # only works if all parameters of the model are stored on a single GPU self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id @@ -86,10 +91,7 @@ class BeamSearch(nn.Module): self._init_beam_state(batch_size) def __len__(self): - try: - return self.growing_beams.size(1) - except NameError: - return 0 + return self.growing_beams.size(1) def _init_beam_state(self, batch_size): """ (re-)Initialize the state of the beams. """ @@ -120,7 +122,7 @@ class BeamSearch(nn.Module): self._step = 0 self.is_done = False - def forward(self, encoder_input_ids, **model_kwargs): + def __call__(self, encoder_input_ids, **model_kwargs): """ Generate a sequence using Beam Search. """ # keyword arguments come in 3 flavors: encoder-specific (prefixed by # `encoder_`), decoder-specific (prefixed by `decoder_`) and those @@ -158,28 +160,17 @@ class BeamSearch(nn.Module): kwargs_encoder["attention_mask"], self.beam_size, dim=0 ) - # grow the beam by generating sequences in an autoregressive way + # grow the beam iteratively batch_size, block_size = encoder_input_ids.size() self._init_beam_state(batch_size) for step in range(self.max_length): - # Add padding tokens - decoder_input = torch.full( - (self.growing_beams.size(0), block_size), - self.pad_token_id, - dtype=torch.long, - device=self.growing_beams.device, - ) - decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams - - # compute decoder_attention_mask - decoder_mask = torch.ones_like(decoder_input) - idx_pad_tokens = decoder_input == self.pad_token_id - decoder_mask[idx_pad_tokens] = 0 - kwargs_decoder["attention_mask"] = decoder_mask + decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id) + kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id) outputs = 
self.model.decoder(decoder_input, **kwargs_decoder) - last_token_scores = outputs[0][:, -1, :].squeeze(1) - log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0) + + next_token_scores = outputs[0][:, -1, :].squeeze(1) + log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0) surviving_beams_rows = self.grow(log_probabilities) if self.is_done: break @@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id): """ Adapt the source and target sequences' lengths to the block size. If the sequence is shorter we append padding tokens to the right. """ - if len(sequence) > block_size: - return sequence[:block_size] - else: - return torch.cat( - (sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0 - ) - - -def build_lm_labels(sequence, pad_token_id): - """ Padding token, encoded as 0, are represented by the value -1 so they - are not taken into account in the loss computation. """ - padded = sequence.clone() - padded[padded == pad_token_id] = -1 - return padded + padded_sequence = torch.full( + (sequence.size(0), block_size), + pad_token_id, + dtype=torch.long, + device=sequence.device, + ) + padded_sequence[:, : sequence.size(1)] = sequence + return padded_sequence def build_mask(sequence, pad_token_id): diff --git a/transformers/tests/beam_search_tests.py b/transformers/tests/beam_search_tests.py index a92ebf3578..6f2a2b9c2f 100644 --- a/transformers/tests/beam_search_tests.py +++ b/transformers/tests/beam_search_tests.py @@ -1,15 +1,22 @@ from collections import namedtuple import unittest - +import pytest import numpy as np import torch +from torch import nn from transformers.generate import BeamSearch from transformers import PreTrainedEncoderDecoder -StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"]) -StubTransformer = namedtuple("Transformer", ["encoder", "decoder"]) +class StubTransformer(nn.Module): + def __init__(self): + self.encoder = None 
+ self.decoder = None + self._parameters = {"dumy": torch.tensor([1])} + + def forward(self): + pass class BeamSearchtest(unittest.TestCase): @@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase): class will break the integration with the beam search. """ - model = PreTrainedEncoderDecoder("encoder", "decoder") - tokenizer = StubTokenizer(0, 1, 2) + model = StubTransformer() try: _ = BeamSearch( model=model, - tokenizer=tokenizer, + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, batch_size=1, beam_size=1, min_length=1, @@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase): min_length = 5 beam = BeamSearch( - model=StubTransformer("encoder", "decoder"), - tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2), + model=StubTransformer(), + bos_token_id=0, + eos_token_id=eos_idx, + pad_token_id=2, batch_size=batch_size, beam_size=beam_size, min_length=5, @@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase): log_probabilities[0, eos_idx] = score_distribution[0] for idx, score in zip(non_eos_idxs, score_distribution[1:]): log_probabilities[0, idx] = score - + pytest.set_trace() for step in range(1, min_length + 2): log_probabilities[0, eos_idx] = score_distribution[0] # Beam #3 and #4 teminate at the first step since the probability # of the [EOS] token is -1e20 > -\infty so there are only two beams left. + # The top beam (most likely) always ends with 4 until we reach min_length. 
surviving_beams_rows = beam.grow(log_probabilities) if step < min_length: np.testing.assert_array_equal( - beam.growing_beams.numpy(), - np.repeat(np.array([[0] + [4] * step]), 2, axis=0), + beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step) ) elif step == min_length: np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([])) @@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase): vocab_size = 10 beam = BeamSearch( - model=StubTransformer("encoder", "decoder"), - tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + model=StubTransformer(), + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, batch_size=batch_size, beam_size=beam_size, min_length=2, @@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase): vocab_size = 10 beam = BeamSearch( - model=StubTransformer("encoder", "decoder"), - tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + model=StubTransformer(), + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, batch_size=batch_size, beam_size=beam_size, min_length=2, @@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase): log_probabilities[::beam_size, idx] = score surviving_beams_rows = beam.grow(log_probabilities) - log_probabilities = log_probabilities.index_select(0, surviving_beams_rows) if step < 7: self.assertFalse( @@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase): np.array([-1e20] * vocab_size, dtype="float32"), ) + log_probabilities = log_probabilities.index_select(0, surviving_beams_rows) + def test_beam_search_example_for_one_step(self): """ We test that the predictions for one step of growth are correct. 
""" batch_size = 2 @@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase): vocab_size = 5 beam = BeamSearch( - model=StubTransformer("encoder", "decoder"), - tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2), + model=StubTransformer(), + bos_token_id=0, + eos_token_id=1, + pad_token_id=2, batch_size=batch_size, beam_size=beam_size, min_length=2,