Constrained Beam Search [*With* Disjunctive Decoding] (#15761)

* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements

* finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing.

* fixed bug found in constrained beam search that used beam_idx that were not global across all the batches

* disjunctive constraint working 100% correctly

* passing all tests

* Accidentally included mlruns

* Update src/transformers/generation_beam_constraints.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_beam_constraints.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* complete overhaul of type complexities and other nits

* strict type checks in generate()

* fixing second round of feedback by narsil

* fixed failing generation test because of type check overhaul

* generation test fail fix

* fixing test fails

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Chan Woo Kim
2022-03-05 02:18:34 +09:00
committed by GitHub
parent 040c11f6da
commit 5c6f57ee75
9 changed files with 587 additions and 76 deletions

View File

@@ -39,7 +39,7 @@ if is_torch_available():
VisionEncoderDecoderModel,
top_k_top_p_filtering,
)
from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import (
ForcedBOSTokenLogitsProcessor,
@@ -1202,7 +1202,7 @@ class GenerationTesterMixin:
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
@@ -1227,7 +1227,7 @@ class GenerationTesterMixin:
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
# Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
@@ -1288,7 +1288,7 @@ class GenerationTesterMixin:
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
min_id = 3
max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
@@ -1499,18 +1499,23 @@ class GenerationTesterMixin:
)
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
if not isinstance(tensor_1, list):
tensor_1 = tensor_1.cpu().tolist()
if not isinstance(tensor_2, list):
tensor_2 = tensor_2.cpu().tolist()
in_order = len(tensor_1) <= len(tensor_2)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
chunk_size = len(shorter)
for chunk_idx in range(len(longer) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
if subseq == shorter:
flag = True
break
@@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
constraints = [
PhrasalConstraint(force_tokens),
@@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase):
],
)
@slow
def test_constrained_beam_search_mixed(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
).input_ids
constraints = [
PhrasalConstraint(force_phrase),
DisjunctiveConstraint(flexible_phrases),
]
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
# max_length=20,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
],
)
@slow
def test_constrained_beam_search_mixed_mixin(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
force_words_ids = [
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
],
)
@slow
def test_constrained_beam_search_example_translation_mixin(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
force_words = ["sind"]
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
@slow
def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
@@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase):
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
with self.assertRaises(ValueError):
force_words = ["sind"]
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids
model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
with self.assertRaises(ValueError):
force_words = ["sind"]
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids]
model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[])
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[-1]])
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])