tweaks to the BeamSearch API
This commit is contained in:
committed by
Julien Chaumond
parent
ba089c780b
commit
4735c2af07
@@ -32,7 +32,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BeamSearch(nn.Module):
|
class BeamSearch(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@@ -45,12 +45,17 @@ class BeamSearch(nn.Module):
|
|||||||
max_length,
|
max_length,
|
||||||
alpha=0,
|
alpha=0,
|
||||||
block_repeating_trigrams=True,
|
block_repeating_trigrams=True,
|
||||||
device=torch.device("cpu"),
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
||||||
The pretrained encoder-decoder model that will be used to generate the sequences.
|
The pretrained encoder-decoder model that will be used to generate the sequences.
|
||||||
|
**bos_token_id**: int
|
||||||
|
Id that is used by the tokenizer to represent the beginning of a sentence.
|
||||||
|
**pad_token_id**: int
|
||||||
|
Id that is used by the tokenizer for padding.
|
||||||
|
**eos_token_id**: int
|
||||||
|
Id that is used by the tokenizer to represent the end of a sentence.
|
||||||
**batch_size**: (`optional`) int
|
**batch_size**: (`optional`) int
|
||||||
Batch size of the inputs. The value is set automatically when calling `forward`.
|
Batch size of the inputs. The value is set automatically when calling `forward`.
|
||||||
**beam_size**: int
|
**beam_size**: int
|
||||||
@@ -68,7 +73,7 @@ class BeamSearch(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super(BeamSearch, self).__init__()
|
super(BeamSearch, self).__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = next(model.parameters()).device # only works if all parameters of the model are stored on a single GPU
|
||||||
|
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
@@ -86,10 +91,7 @@ class BeamSearch(nn.Module):
|
|||||||
self._init_beam_state(batch_size)
|
self._init_beam_state(batch_size)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
try:
|
|
||||||
return self.growing_beams.size(1)
|
return self.growing_beams.size(1)
|
||||||
except NameError:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def _init_beam_state(self, batch_size):
|
def _init_beam_state(self, batch_size):
|
||||||
""" (re-)Initialize the state of the beams. """
|
""" (re-)Initialize the state of the beams. """
|
||||||
@@ -120,7 +122,7 @@ class BeamSearch(nn.Module):
|
|||||||
self._step = 0
|
self._step = 0
|
||||||
self.is_done = False
|
self.is_done = False
|
||||||
|
|
||||||
def forward(self, encoder_input_ids, **model_kwargs):
|
def __call__(self, encoder_input_ids, **model_kwargs):
|
||||||
""" Generate a sequence using Beam Search. """
|
""" Generate a sequence using Beam Search. """
|
||||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||||
@@ -158,28 +160,17 @@ class BeamSearch(nn.Module):
|
|||||||
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# grow the beam by generating sequences in an autoregressive way
|
# grow the beam iteratively
|
||||||
batch_size, block_size = encoder_input_ids.size()
|
batch_size, block_size = encoder_input_ids.size()
|
||||||
self._init_beam_state(batch_size)
|
self._init_beam_state(batch_size)
|
||||||
for step in range(self.max_length):
|
for step in range(self.max_length):
|
||||||
# Add padding tokens
|
|
||||||
decoder_input = torch.full(
|
|
||||||
(self.growing_beams.size(0), block_size),
|
|
||||||
self.pad_token_id,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.growing_beams.device,
|
|
||||||
)
|
|
||||||
decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
|
|
||||||
|
|
||||||
# compute decoder_attention_mask
|
|
||||||
decoder_mask = torch.ones_like(decoder_input)
|
|
||||||
idx_pad_tokens = decoder_input == self.pad_token_id
|
|
||||||
decoder_mask[idx_pad_tokens] = 0
|
|
||||||
kwargs_decoder["attention_mask"] = decoder_mask
|
|
||||||
|
|
||||||
|
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
|
||||||
|
kwargs_decoder["attention_mask"] = build_mask(decoder_input)
|
||||||
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
||||||
last_token_scores = outputs[0][:, -1, :].squeeze(1)
|
|
||||||
log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
|
next_token_scores = outputs[0][:, -1, :].squeeze(1)
|
||||||
|
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
|
||||||
surviving_beams_rows = self.grow(log_probabilities)
|
surviving_beams_rows = self.grow(log_probabilities)
|
||||||
if self.is_done:
|
if self.is_done:
|
||||||
break
|
break
|
||||||
@@ -356,20 +347,14 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
|
|||||||
""" Adapt the source and target sequences' lengths to the block size.
|
""" Adapt the source and target sequences' lengths to the block size.
|
||||||
If the sequence is shorter we append padding tokens to the right.
|
If the sequence is shorter we append padding tokens to the right.
|
||||||
"""
|
"""
|
||||||
if len(sequence) > block_size:
|
padded_sequence = torch.full(
|
||||||
return sequence[:block_size]
|
(sequence.size(0), block_size),
|
||||||
else:
|
pad_token_id,
|
||||||
return torch.cat(
|
dtype=torch.long,
|
||||||
(sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
|
device=sequence.device,
|
||||||
)
|
)
|
||||||
|
padded_sequence[:, : sequence.size(1)] = sequence
|
||||||
|
return sequence
|
||||||
def build_lm_labels(sequence, pad_token_id):
|
|
||||||
""" Padding token, encoded as 0, are represented by the value -1 so they
|
|
||||||
are not taken into account in the loss computation. """
|
|
||||||
padded = sequence.clone()
|
|
||||||
padded[padded == pad_token_id] = -1
|
|
||||||
return padded
|
|
||||||
|
|
||||||
|
|
||||||
def build_mask(sequence, pad_token_id):
|
def build_mask(sequence, pad_token_id):
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from transformers.generate import BeamSearch
|
from transformers.generate import BeamSearch
|
||||||
from transformers import PreTrainedEncoderDecoder
|
from transformers import PreTrainedEncoderDecoder
|
||||||
|
|
||||||
|
|
||||||
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
|
class StubTransformer(nn.Module):
|
||||||
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
|
def __init__(self):
|
||||||
|
self.encoder = None
|
||||||
|
self.decoder = None
|
||||||
|
self._parameters = {"dumy": torch.tensor([1])}
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BeamSearchtest(unittest.TestCase):
|
class BeamSearchtest(unittest.TestCase):
|
||||||
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
class will break the integration with the beam search.
|
class will break the integration with the beam search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = PreTrainedEncoderDecoder("encoder", "decoder")
|
model = StubTransformer()
|
||||||
tokenizer = StubTokenizer(0, 1, 2)
|
|
||||||
try:
|
try:
|
||||||
_ = BeamSearch(
|
_ = BeamSearch(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
bos_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
pad_token_id=2,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
beam_size=1,
|
beam_size=1,
|
||||||
min_length=1,
|
min_length=1,
|
||||||
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
min_length = 5
|
min_length = 5
|
||||||
|
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
model=StubTransformer("encoder", "decoder"),
|
model=StubTransformer(),
|
||||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
|
bos_token_id=0,
|
||||||
|
eos_token_id=eos_idx,
|
||||||
|
pad_token_id=2,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
min_length=5,
|
min_length=5,
|
||||||
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||||
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
|
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
|
||||||
log_probabilities[0, idx] = score
|
log_probabilities[0, idx] = score
|
||||||
|
pytest.set_trace()
|
||||||
for step in range(1, min_length + 2):
|
for step in range(1, min_length + 2):
|
||||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||||
|
|
||||||
# Beam #3 and #4 terminate at the first step since the probability
|
# Beam #3 and #4 terminate at the first step since the probability
|
||||||
# of the [EOS] token is -1e20 > -\infty so there are only two beams left.
|
# of the [EOS] token is -1e20 > -\infty so there are only two beams left.
|
||||||
|
# The top beam (most likely) always ends with 4 until we reach min_length.
|
||||||
surviving_beams_rows = beam.grow(log_probabilities)
|
surviving_beams_rows = beam.grow(log_probabilities)
|
||||||
if step < min_length:
|
if step < min_length:
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
beam.growing_beams.numpy(),
|
beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
|
||||||
np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
|
|
||||||
)
|
)
|
||||||
elif step == min_length:
|
elif step == min_length:
|
||||||
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
|
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
|
||||||
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
|
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
model=StubTransformer("encoder", "decoder"),
|
model=StubTransformer(),
|
||||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
bos_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
pad_token_id=2,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
min_length=2,
|
min_length=2,
|
||||||
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
vocab_size = 10
|
vocab_size = 10
|
||||||
|
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
model=StubTransformer("encoder", "decoder"),
|
model=StubTransformer(),
|
||||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
bos_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
pad_token_id=2,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
min_length=2,
|
min_length=2,
|
||||||
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
log_probabilities[::beam_size, idx] = score
|
log_probabilities[::beam_size, idx] = score
|
||||||
|
|
||||||
surviving_beams_rows = beam.grow(log_probabilities)
|
surviving_beams_rows = beam.grow(log_probabilities)
|
||||||
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
|
||||||
|
|
||||||
if step < 7:
|
if step < 7:
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
np.array([-1e20] * vocab_size, dtype="float32"),
|
np.array([-1e20] * vocab_size, dtype="float32"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
||||||
|
|
||||||
def test_beam_search_example_for_one_step(self):
|
def test_beam_search_example_for_one_step(self):
|
||||||
""" We test that the predictions for one step of growth are correct. """
|
""" We test that the predictions for one step of growth are correct. """
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
|
|||||||
vocab_size = 5
|
vocab_size = 5
|
||||||
|
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
model=StubTransformer("encoder", "decoder"),
|
model=StubTransformer(),
|
||||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
bos_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
pad_token_id=2,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
beam_size=beam_size,
|
beam_size=beam_size,
|
||||||
min_length=2,
|
min_length=2,
|
||||||
|
|||||||
Reference in New Issue
Block a user