tweaks to the BeamSearch API

This commit is contained in:
Rémi Louf
2019-11-08 11:16:26 +01:00
committed by Julien Chaumond
parent ba089c780b
commit 4735c2af07
2 changed files with 59 additions and 57 deletions

View File

@@ -1,15 +1,22 @@
from collections import namedtuple
import unittest
import pytest
import numpy as np
import torch
from torch import nn
from transformers.generate import BeamSearch
from transformers import PreTrainedEncoderDecoder
# Stand-in tokenizer that only carries the three special-token ids.
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
# NOTE(review): the `class StubTransformer` statement that follows rebinds this
# same name, so the namedtuple is shadowed from then on -- presumably a
# leftover; confirm before relying on it.
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
class StubTransformer(nn.Module):
    """Minimal ``nn.Module`` stand-in used as the ``model`` of ``BeamSearch``.

    The stub exposes ``encoder`` and ``decoder`` attributes (both ``None``)
    and stores one placeholder tensor in ``_parameters`` -- presumably so
    that code inspecting the model's parameters has something to read
    (TODO confirm against ``BeamSearch``).
    """

    def __init__(self):
        # Set up nn.Module's internal registries (_modules, _buffers, hooks).
        # Without this call, walking the module (e.g. ``list(parameters())``
        # or ``children()``) fails with AttributeError.
        super(StubTransformer, self).__init__()
        self.encoder = None
        self.decoder = None
        # Placeholder entry; the key spelling is kept unchanged on purpose.
        self._parameters = {"dumy": torch.tensor([1])}

    def forward(self):
        """Do nothing and return ``None``."""
        pass
class BeamSearchtest(unittest.TestCase):
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
class will break the integration with the beam search.
"""
model = PreTrainedEncoderDecoder("encoder", "decoder")
tokenizer = StubTokenizer(0, 1, 2)
model = StubTransformer()
try:
_ = BeamSearch(
model=model,
tokenizer=tokenizer,
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=1,
beam_size=1,
min_length=1,
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
min_length = 5
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
model=StubTransformer(),
bos_token_id=0,
eos_token_id=eos_idx,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=5,
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
log_probabilities[0, eos_idx] = score_distribution[0]
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
log_probabilities[0, idx] = score
pytest.set_trace()
for step in range(1, min_length + 2):
log_probabilities[0, eos_idx] = score_distribution[0]
# Beam #3 and #4 terminate at the first step since the probability
# of the [EOS] token is -1e20 > -\infty so there are only two beams left.
# The top beam (most likely) always ends with 4 until we reach min_length.
surviving_beams_rows = beam.grow(log_probabilities)
if step < min_length:
np.testing.assert_array_equal(
beam.growing_beams.numpy(),
np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
)
elif step == min_length:
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
vocab_size = 10
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
vocab_size = 10
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
log_probabilities[::beam_size, idx] = score
surviving_beams_rows = beam.grow(log_probabilities)
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
if step < 7:
self.assertFalse(
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
np.array([-1e20] * vocab_size, dtype="float32"),
)
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_example_for_one_step(self):
""" We test that the predictions for one step of growth are correct. """
batch_size = 2
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
vocab_size = 5
beam = BeamSearch(
model=StubTransformer("encoder", "decoder"),
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,