tweaks to the BeamSearch API
This commit is contained in:
committed by
Julien Chaumond
parent
ba089c780b
commit
4735c2af07
@@ -32,7 +32,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BeamSearch(nn.Module):
|
||||
class BeamSearch(object):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@@ -45,12 +45,17 @@ class BeamSearch(nn.Module):
|
||||
max_length,
|
||||
alpha=0,
|
||||
block_repeating_trigrams=True,
|
||||
device=torch.device("cpu"),
|
||||
):
|
||||
r"""
|
||||
Inputs:
|
||||
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
||||
The pretrained encoder-decoder model that will be used to generate the sequences.
|
||||
**bos_token_id**: int
|
||||
Id that is used by the tokenizer to represent the beginning of a sentence.
|
||||
**pad_token_id**: int
|
||||
Id that is used by the tokenizer for padding.
|
||||
**eos_token_id**: int
|
||||
Id that is used by the tokenizer to represent the end of a sentence.
|
||||
**batch_size**: (`optional`) int
|
||||
Batch size of the inputs. The value is set automatically when calling `forward`.
|
||||
**beam_size**: int
|
||||
@@ -68,7 +73,7 @@ class BeamSearch(nn.Module):
|
||||
"""
|
||||
super(BeamSearch, self).__init__()
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.device = next(model.parameters()).device # only works if all parameters of the model are stored on a single GPU
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
@@ -86,10 +91,7 @@ class BeamSearch(nn.Module):
|
||||
self._init_beam_state(batch_size)
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
return self.growing_beams.size(1)
|
||||
except NameError:
|
||||
return 0
|
||||
return self.growing_beams.size(1)
|
||||
|
||||
def _init_beam_state(self, batch_size):
|
||||
""" (re-)Initialize the state of the beams. """
|
||||
@@ -120,7 +122,7 @@ class BeamSearch(nn.Module):
|
||||
self._step = 0
|
||||
self.is_done = False
|
||||
|
||||
def forward(self, encoder_input_ids, **model_kwargs):
|
||||
def __call__(self, encoder_input_ids, **model_kwargs):
|
||||
""" Generate a sequence using Beam Search. """
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
@@ -158,28 +160,17 @@ class BeamSearch(nn.Module):
|
||||
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
||||
)
|
||||
|
||||
# grow the beam by generating sequences in an autoregressive way
|
||||
# grow the beam iteratively
|
||||
batch_size, block_size = encoder_input_ids.size()
|
||||
self._init_beam_state(batch_size)
|
||||
for step in range(self.max_length):
|
||||
# Add padding tokens
|
||||
decoder_input = torch.full(
|
||||
(self.growing_beams.size(0), block_size),
|
||||
self.pad_token_id,
|
||||
dtype=torch.long,
|
||||
device=self.growing_beams.device,
|
||||
)
|
||||
decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
|
||||
|
||||
# compute decoder_attention_mask
|
||||
decoder_mask = torch.ones_like(decoder_input)
|
||||
idx_pad_tokens = decoder_input == self.pad_token_id
|
||||
decoder_mask[idx_pad_tokens] = 0
|
||||
kwargs_decoder["attention_mask"] = decoder_mask
|
||||
|
||||
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
|
||||
kwargs_decoder["attention_mask"] = build_mask(decoder_input)
|
||||
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
||||
last_token_scores = outputs[0][:, -1, :].squeeze(1)
|
||||
log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
|
||||
|
||||
next_token_scores = outputs[0][:, -1, :].squeeze(1)
|
||||
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
|
||||
surviving_beams_rows = self.grow(log_probabilities)
|
||||
if self.is_done:
|
||||
break
|
||||
@@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding tokens to the right.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
return torch.cat(
|
||||
(sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
|
||||
)
|
||||
|
||||
|
||||
def build_lm_labels(sequence, pad_token_id):
|
||||
""" Padding token, encoded as 0, are represented by the value -1 so they
|
||||
are not taken into account in the loss computation. """
|
||||
padded = sequence.clone()
|
||||
padded[padded == pad_token_id] = -1
|
||||
return padded
|
||||
padded_sequence = torch.full(
|
||||
(sequence.size(0), block_size),
|
||||
pad_token_id,
|
||||
dtype=torch.long,
|
||||
device=sequence.device,
|
||||
)
|
||||
padded_sequence[:, : sequence.size(1)] = sequence
|
||||
return sequence
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
from collections import namedtuple
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.generate import BeamSearch
|
||||
from transformers import PreTrainedEncoderDecoder
|
||||
|
||||
|
||||
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
|
||||
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
|
||||
class StubTransformer(nn.Module):
|
||||
def __init__(self):
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self._parameters = {"dumy": torch.tensor([1])}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
|
||||
class BeamSearchtest(unittest.TestCase):
|
||||
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
|
||||
class will break the integration with the beam search.
|
||||
"""
|
||||
|
||||
model = PreTrainedEncoderDecoder("encoder", "decoder")
|
||||
tokenizer = StubTokenizer(0, 1, 2)
|
||||
model = StubTransformer()
|
||||
try:
|
||||
_ = BeamSearch(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=1,
|
||||
beam_size=1,
|
||||
min_length=1,
|
||||
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
min_length = 5
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=eos_idx,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=5,
|
||||
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
|
||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
|
||||
log_probabilities[0, idx] = score
|
||||
|
||||
pytest.set_trace()
|
||||
for step in range(1, min_length + 2):
|
||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||
|
||||
# Beam #3 and #4 terminate at the first step since the probability
|
||||
# of the [EOS] token is -1e20 > -\infty so there are only two beams left.
|
||||
# The top beam (most likely) always ends with 4 until we reach min_length.
|
||||
surviving_beams_rows = beam.grow(log_probabilities)
|
||||
if step < min_length:
|
||||
np.testing.assert_array_equal(
|
||||
beam.growing_beams.numpy(),
|
||||
np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
|
||||
beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
|
||||
)
|
||||
elif step == min_length:
|
||||
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
|
||||
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 10
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 10
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
|
||||
log_probabilities[::beam_size, idx] = score
|
||||
|
||||
surviving_beams_rows = beam.grow(log_probabilities)
|
||||
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
||||
|
||||
if step < 7:
|
||||
self.assertFalse(
|
||||
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
|
||||
np.array([-1e20] * vocab_size, dtype="float32"),
|
||||
)
|
||||
|
||||
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
||||
|
||||
def test_beam_search_example_for_one_step(self):
|
||||
""" We test that the predictions for one step of growth are correct. """
|
||||
batch_size = 2
|
||||
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 5
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
|
||||
Reference in New Issue
Block a user