Constrained Beam Search [without disjunctive decoding] (#15416)
* added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undesirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and found & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should work but fails without the beam search modification; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrasalConstraint docstring * docstrings improvements Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -27,6 +27,8 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
@@ -37,7 +39,8 @@ if is_torch_available():
|
||||
VisionEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
||||
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
@@ -190,6 +193,25 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
    """Build a `ConstrainedBeamSearchScorer` together with the matching `generate()` keyword arguments.

    Args:
        batch_size: forwarded to the scorer as `batch_size`.
        max_length: accepted for signature parity with the callers; not used by this helper.
        constraints: forwarded to the scorer as `constraints`.
        num_return_sequences: number of hypotheses to keep; the beam count is four times this value.

    Returns:
        A `(beam_kwargs, beam_scorer)` tuple whose settings agree with each other.
    """
    # Shared settings, defined once so the kwargs dict and the scorer cannot drift apart.
    use_early_stopping = False
    penalty = 2.0
    beam_count = num_return_sequences * 4

    scorer = ConstrainedBeamSearchScorer(
        batch_size=batch_size,
        constraints=constraints,
        num_beams=beam_count,
        device=torch_device,
        length_penalty=penalty,
        do_early_stopping=use_early_stopping,
        num_beam_hyps_to_keep=num_return_sequences,
    )
    generate_kwargs = {
        "early_stopping": use_early_stopping,
        "length_penalty": penalty,
        "num_beams": beam_count,
        "num_return_sequences": num_return_sequences,
    }
    return generate_kwargs, scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
@@ -526,6 +548,69 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
def _constrained_beam_search_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    constrained_beam_scorer,
    constraints,
    beam_kwargs,
    logits_processor,
    logits_process_kwargs,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Run constrained beam search twice: via `generate()` and via a direct `constrained_beam_search()` call.

    The `generate()` call receives `constraints`, `beam_kwargs` and `logits_process_kwargs`; the direct call
    receives the pre-built `constrained_beam_scorer` and `logits_processor` instead.

    Returns:
        A `(output_generate, output_constrained_beam_search)` tuple, in that order.
    """
    # Output switches passed unchanged to both entry points.
    output_options = {
        "output_scores": output_scores,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict_in_generate": return_dict_in_generate,
    }

    output_generate = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_length=max_length,
        remove_invalid_values=True,
        constraints=constraints,
        **output_options,
        **beam_kwargs,
        **logits_process_kwargs,
    )

    # For the direct call the inputs are expanded by hand: every batch row is repeated `num_beams` times.
    num_beams = constrained_beam_scorer.num_beams
    search_kwargs = {}
    if not model.config.is_encoder_decoder:
        expanded_attention_mask = attention_mask.repeat_interleave(num_beams, dim=0)
        expanded_input_ids = input_ids.repeat_interleave(num_beams, dim=0)
    else:
        encoder_outputs, decoder_input_ids, expanded_attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            num_interleave=num_beams,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        search_kwargs["encoder_outputs"] = encoder_outputs
        expanded_input_ids = decoder_input_ids.repeat_interleave(num_beams, dim=0)

    with torch.no_grad():
        output_constrained_beam_search = model.constrained_beam_search(
            expanded_input_ids,
            constrained_beam_scorer,
            max_length=max_length,
            attention_mask=expanded_attention_mask,
            logits_processor=logits_processor,
            **output_options,
            **search_kwargs,
        )
    return output_generate, output_constrained_beam_search
|
||||
|
||||
def test_greedy_generate(self):
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -719,6 +804,7 @@ class GenerationTesterMixin:
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
|
||||
@@ -1085,6 +1171,164 @@ class GenerationTesterMixin:
|
||||
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
||||
)
|
||||
|
||||
def test_constrained_beam_search_generate(self):
    """Check that `generate()` and `constrained_beam_search()` agree and honour a phrasal constraint.

    For every generative model class the comparison is run twice, first with one returned sequence and then
    with two, each time with a freshly sampled two-token constraint that must appear in every output.
    """
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, _ = self._get_input_ids_and_config()

        # It is important to set the eos_token_id to None to ensure that no sequences
        # shorter than `max_length` can be generated, which could lead to flaky circle ci
        # failures if the top `num_return_sequences` beams are all shorter than the longest beam
        config.eos_token_id = None
        config.forced_eos_token_id = None

        model = model_class(config).to(torch_device).eval()
        max_length = 20

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            config.eos_token_id,
            config.forced_bos_token_id,
            config.forced_eos_token_id,
            max_length,
        )

        # Range from which constraint token ids are sampled.
        if input_ids.dtype == torch.float32:
            # floating point inputs (Speech2TextModel) cannot supply an id range, so use a fixed one
            min_id, max_id = 3, 100
        else:
            min_id = torch.min(input_ids) + 3
            max_id = torch.max(input_ids)

        # first pass: single returned sequence; second pass: `num_return_sequences` > 1
        for num_return_sequences in (1, 2):
            # Sample constraints
            force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
            constraints = [PhrasalConstraint(force_tokens)]

            beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
            )
            output_generate, output_beam_search = self._constrained_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                constrained_beam_scorer=beam_scorer,
                constraints=constraints,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )

            # both entry points must produce the same sequences
            self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

            # every generated sequence must contain the forced tokens
            for generation_output in output_generate:
                self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
    """Check the dict-style outputs of `generate()` and `constrained_beam_search()` against each other.

    With scores, hidden states and attentions all requested, both entry points must return the expected
    output class, identical sequences, and sequence scores that are close, negative and one per sequence.
    """
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

        # disable cache
        config.use_cache = False

        # It is important to set the eos_token_id to None to ensure that no sequences
        # shorter than `max_length` can be generated, which could lead to flaky circle ci
        # failures if the top `num_return_sequences` beams are all shorter than the longest beam
        config.eos_token_id = None
        config.forced_eos_token_id = None

        model = model_class(config).to(torch_device).eval()
        is_encoder_decoder = model.config.is_encoder_decoder
        if is_encoder_decoder:
            max_length = 20

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            config.eos_token_id,
            config.forced_bos_token_id,
            config.forced_eos_token_id,
            max_length,
        )

        # Sample constraints
        if input_ids.dtype == torch.float32:
            # floating point inputs (Speech2TextModel) cannot supply an id range, so use a fixed one
            min_id, max_id = 3, 100
        else:
            min_id = torch.min(input_ids) + 3
            max_id = torch.max(input_ids)
        force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
        constraints = [PhrasalConstraint(force_tokens)]

        beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
            input_ids.shape[0], max_length, constraints, num_return_sequences=1
        )
        output_generate, output_beam_search = self._constrained_beam_search_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            constrained_beam_scorer=beam_scorer,
            constraints=constraints,
            beam_kwargs=beam_kwargs,
            logits_processor=logits_processor,
            logits_process_kwargs=logits_process_kwargs,
            output_scores=True,
            output_hidden_states=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )

        # The output container type depends on the model architecture.
        if is_encoder_decoder:
            expected_output_class = BeamSearchEncoderDecoderOutput
        else:
            expected_output_class = BeamSearchDecoderOnlyOutput
        self.assertIsInstance(output_beam_search, expected_output_class)
        self.assertIsInstance(output_generate, expected_output_class)

        self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())

        generate_scores = output_generate["sequences_scores"]
        search_scores = output_beam_search["sequences_scores"]
        self.assertTrue(torch.allclose(generate_scores, search_scores, atol=1e-3))

        # one score per returned sequence, and every score negative
        expected_shape = (output_generate["sequences"].shape[0],)
        self.assertTrue(generate_scores.shape == expected_shape)
        self.assertTrue((generate_scores < 0).all().item())

        for output in (output_beam_search, output_generate):
            self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
@@ -1254,6 +1498,24 @@ class GenerationTesterMixin:
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
    """Assert that the shorter of two tensors occurs as a contiguous slice of the longer one.

    The tensors are compared along dimension 0. The argument order does not matter: whichever tensor
    has the smaller first dimension is searched for inside the other (on a tie, `tensor_1` is treated
    as the shorter one).

    Args:
        tensor_1: first tensor.
        tensor_2: second tensor.

    Raises:
        AssertionError: via `self.assertTrue` when no window of the longer tensor equals the shorter one.
    """
    # Bring both tensors onto the same device (CPU); which device is used is irrelevant here.
    tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()

    if tensor_1.size(0) <= tensor_2.size(0):
        shorter, longer = tensor_1, tensor_2
    else:
        shorter, longer = tensor_2, tensor_1

    # Slide a window of the shorter tensor's length over the longer one; `any` stops at the first match.
    window = shorter.size(0)
    found = any(
        torch.equal(longer[start : start + window], shorter) for start in range(longer.size(0) - window + 1)
    )

    # Include both tensors in the message so a failing run shows what was searched for and where.
    self.assertTrue(found, f"{shorter.tolist()} is not a contiguous subsequence of {longer.tolist()}")
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
@@ -2047,3 +2309,83 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
transition_scores_sum = transition_scores.sum(-1)
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
def test_constrained_beam_search(self):
    """Generate with GPT-2 under two phrasal constraints and compare against the expected text."""
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)

    # Each phrase becomes one constraint built from its token ids.
    forced_phrases = (" scared", " big weapons")
    constraints = [
        PhrasalConstraint(tokenizer.encode(phrase, return_tensors="pt").to(torch_device)[0])
        for phrase in forced_phrases
    ]

    prompt = ["The soldiers were not prepared and"]
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

    generated_ids = model.generate(
        prompt_ids,
        constraints=constraints,
        num_beams=10,
        num_return_sequences=1,
        no_repeat_ngram_size=1,
        max_length=30,
        remove_invalid_values=True,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected = [
        "The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them off. They had no idea what to do",
    ]
    self.assertListEqual(decoded, expected)
|
||||
|
||||
@slow
def test_constrained_beam_search_example_integration(self):
    """Call `constrained_beam_search()` directly on t5-base, forcing "sind" into a German translation."""
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    tokenizer = AutoTokenizer.from_pretrained("t5-base")

    source_text = "translate English to German: How old are you?"
    source_ids = tokenizer(source_text, return_tensors="pt").input_ids

    # run beam search with 5 beams
    num_beams = 5

    # one decoder start token per beam
    start_token_id = model.config.decoder_start_token_id
    decoder_input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) * start_token_id

    # the encoder is run once on the source, repeated for every beam, and its outputs passed as a model kwarg
    encoder = model.get_encoder()
    encoder_outputs = encoder(source_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
    model_kwargs = {"encoder_outputs": encoder_outputs}

    # the forced word, tokenized without the trailing eos token
    forced_word = "sind"
    forced_token_ids = tokenizer.encode(forced_word)[:-1]
    constraints = [PhrasalConstraint(token_ids=forced_token_ids)]

    # beam scorer that tracks the constraints
    beam_scorer = ConstrainedBeamSearchScorer(
        batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
    )

    # logits processors
    min_length_processor = MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)
    logits_processor = LogitsProcessorList([min_length_processor])

    generated_ids = model.constrained_beam_search(
        decoder_input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    self.assertListEqual(decoded, ["Wie alter sind Sie?"])
|
||||
|
||||
Reference in New Issue
Block a user