Constrained Beam Search [without disjunctive decoding] (#15416)

* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Chan Woo Kim
2022-02-10 00:59:26 +09:00
committed by GitHub
parent 0113aae5b7
commit 2b5603f6ac
8 changed files with 1871 additions and 16 deletions

View File

@@ -25,7 +25,8 @@ from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
class BeamSearchTester:
@@ -232,6 +233,270 @@ class BeamSearchTester:
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
class ConstrainedBeamSearchTester:
def __init__(
self,
parent,
constraints=None,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
self.constraints = constraints
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_constrained_beam_scorer(self, **kwargs):
return ConstrainedBeamSearchScorer(
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
scores_for_all_vocab, _ = (
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_constrained_beam_scorer_update(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# check all batches are done
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# beam scorer should be done
self.parent.assertTrue(constrained_beam_scorer.is_done)
# check
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_constrained_beam_scorer_finalize(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
)
constraints = constrained_beam_scorer.constraints
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
for output in sequences:
for constraint in constraints:
forced_token_ids = constraint.token_ids
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
)
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
return flag
@require_torch
class BeamSearchTest(unittest.TestCase):
def setUp(self):
@@ -248,3 +513,21 @@ class BeamSearchTest(unittest.TestCase):
def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs)
@require_torch
class ConstrainedBeamSearchTest(unittest.TestCase):
def setUp(self):
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
def test_constrained_beam_hypotheses(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
def test_constrained_beam_scorer_update(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
def test_constrained_beam_scorer_finalize(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)

View File

@@ -27,6 +27,8 @@ if is_torch_available():
import torch
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
@@ -37,7 +39,8 @@ if is_torch_available():
VisionEncoderDecoderModel,
top_k_top_p_filtering,
)
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
@@ -190,6 +193,25 @@ class GenerationTesterMixin:
)
return beam_kwargs, beam_scorer
@staticmethod
def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=batch_size,
constraints=constraints,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
num_beam_hyps_to_keep=num_return_sequences,
)
return beam_kwargs, beam_scorer
@staticmethod
def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
@@ -526,6 +548,69 @@ class GenerationTesterMixin:
)
return output_generate, output_group_beam_search
def _constrained_beam_search_generate(
self,
model,
input_ids,
attention_mask,
max_length,
constrained_beam_scorer,
constraints,
beam_kwargs,
logits_processor,
logits_process_kwargs,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
)
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=constrained_beam_scorer.num_beams,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_group_beam_search = model.constrained_beam_search(
input_ids_clone,
constrained_beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
return output_generate, output_group_beam_search
def test_greedy_generate(self):
# check `generate()` and `greedy_search()` are equal
for model_class in self.all_generative_model_classes:
@@ -719,6 +804,7 @@ class GenerationTesterMixin:
logits_process_kwargs=logits_process_kwargs,
logits_processor=logits_processor,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
@@ -1085,6 +1171,164 @@ class GenerationTesterMixin:
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
)
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# It is important set set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
# Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
num_return_sequences = 2
max_length = 20
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
)
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
# It is important set set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated which could lead to flaky circle ci
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
config.eos_token_id = None
config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
if model.config.is_encoder_decoder:
max_length = 20
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1],
config.eos_token_id,
config.forced_bos_token_id,
config.forced_eos_token_id,
max_length,
)
# Sample constraints
if not input_ids.dtype == torch.float32:
min_id = torch.min(input_ids) + 3
max_id = torch.max(input_ids)
else:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, constraints, num_return_sequences=1
)
output_generate, output_beam_search = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())
self.assertTrue(
torch.allclose(output_generate["sequences_scores"], output_beam_search["sequences_scores"], atol=1e-3)
)
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
for output in (output_beam_search, output_generate):
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1254,6 +1498,24 @@ class GenerationTesterMixin:
[encoder_expected_shape] * len(hidden_states),
)
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
self.assertTrue(flag)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase):
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
constraints = [
PhrasalConstraint(force_tokens),
PhrasalConstraint(force_tokens_2),
]
starting_text = ["The soldiers were not prepared and"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
max_length=30,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
],
)
@slow
def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
# lets run beam search using 5 beams
num_beams = 5
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
model_kwargs = {
"encoder_outputs": model.get_encoder()(
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
)
}
constraint_str = "sind"
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
# instantiate beam scorer
beam_scorer = ConstrainedBeamSearchScorer(
batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
)
# instantiate logits processors
logits_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
]
)
outputs = model.constrained_beam_search(
input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"])