tweaks to the BeamSearch API
This commit is contained in:
committed by
Julien Chaumond
parent
ba089c780b
commit
4735c2af07
@@ -1,15 +1,22 @@
|
||||
from collections import namedtuple
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.generate import BeamSearch
|
||||
from transformers import PreTrainedEncoderDecoder
|
||||
|
||||
|
||||
StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
|
||||
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])
|
||||
class StubTransformer(nn.Module):
|
||||
def __init__(self):
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self._parameters = {"dumy": torch.tensor([1])}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
|
||||
class BeamSearchtest(unittest.TestCase):
|
||||
@@ -18,12 +25,13 @@ class BeamSearchtest(unittest.TestCase):
|
||||
class will break the integration with the beam search.
|
||||
"""
|
||||
|
||||
model = PreTrainedEncoderDecoder("encoder", "decoder")
|
||||
tokenizer = StubTokenizer(0, 1, 2)
|
||||
model = StubTransformer()
|
||||
try:
|
||||
_ = BeamSearch(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=1,
|
||||
beam_size=1,
|
||||
min_length=1,
|
||||
@@ -46,8 +54,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
min_length = 5
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=eos_idx,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=5,
|
||||
@@ -71,17 +81,17 @@ class BeamSearchtest(unittest.TestCase):
|
||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
|
||||
log_probabilities[0, idx] = score
|
||||
|
||||
pytest.set_trace()
|
||||
for step in range(1, min_length + 2):
|
||||
log_probabilities[0, eos_idx] = score_distribution[0]
|
||||
|
||||
# Beam #3 and #4 teminate at the first step since the probability
|
||||
# of the [EOS] token is -1e20 > -\infty so there are only two beams left.
|
||||
# The top beam (most likely) always ends with 4 until we reach min_length.
|
||||
surviving_beams_rows = beam.grow(log_probabilities)
|
||||
if step < min_length:
|
||||
np.testing.assert_array_equal(
|
||||
beam.growing_beams.numpy(),
|
||||
np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
|
||||
beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
|
||||
)
|
||||
elif step == min_length:
|
||||
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
|
||||
@@ -99,8 +109,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 10
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
@@ -140,8 +152,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 10
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
@@ -167,7 +181,6 @@ class BeamSearchtest(unittest.TestCase):
|
||||
log_probabilities[::beam_size, idx] = score
|
||||
|
||||
surviving_beams_rows = beam.grow(log_probabilities)
|
||||
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
||||
|
||||
if step < 7:
|
||||
self.assertFalse(
|
||||
@@ -182,6 +195,8 @@ class BeamSearchtest(unittest.TestCase):
|
||||
np.array([-1e20] * vocab_size, dtype="float32"),
|
||||
)
|
||||
|
||||
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
|
||||
|
||||
def test_beam_search_example_for_one_step(self):
|
||||
""" We test that the predictions for one step of growth are correct. """
|
||||
batch_size = 2
|
||||
@@ -190,8 +205,10 @@ class BeamSearchtest(unittest.TestCase):
|
||||
vocab_size = 5
|
||||
|
||||
beam = BeamSearch(
|
||||
model=StubTransformer("encoder", "decoder"),
|
||||
tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
|
||||
model=StubTransformer(),
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
pad_token_id=2,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
min_length=2,
|
||||
|
||||
Reference in New Issue
Block a user