Diverse beam search 2 (#9006)
* diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain <a.jain@sprinklr.com> Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
67ff1c314a
commit
02d0e0355c
@@ -17,15 +17,16 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
@@ -61,7 +62,7 @@ class GenerationTesterMixin:
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
@@ -70,6 +71,13 @@ class GenerationTesterMixin:
|
||||
}
|
||||
logits_processor = LogitsProcessorList(
|
||||
(
|
||||
[
|
||||
HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
|
||||
]
|
||||
if diversity_penalty is not None
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[
|
||||
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
|
||||
]
|
||||
@@ -115,6 +123,28 @@ class GenerationTesterMixin:
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": 2,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
"num_beam_groups": 2, # one beam per group
|
||||
"diversity_penalty": 2.0,
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
do_early_stopping=beam_kwargs["early_stopping"],
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=beam_kwargs["num_beam_groups"],
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
|
||||
encoder = model.get_encoder()
|
||||
@@ -408,6 +438,92 @@ class GenerationTesterMixin:
|
||||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
def test_group_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_group_beam_search = model.group_beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_group_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.group_beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
@@ -512,3 +628,31 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
||||
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_diverse_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
||||
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
||||
The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""
|
||||
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = bart_model.generate(
|
||||
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
|
||||
)
|
||||
|
||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
|
||||
"Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user