[Beam Search] Correct returned beam scores (#14654)
* better * save intermediate * finish code * up * docs * Apply suggestions from code review * up * add compute transition beam scores function to model and make sure scores are correct with eos * apply nicos comments * Apply suggestions from code review * another fix
This commit is contained in:
committed by
GitHub
parent
e239fc3b0b
commit
8d6acc6c29
@@ -1903,3 +1903,147 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||
|
||||
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder(self):
    """Summed beam transition scores must match `sequences_scores` (BART, beam search).

    Loads a tiny random BART checkpoint with `num_beams=4`, `eos_token_id=None`
    and `length_penalty=0.0`, generates for two short articles, and asserts that
    the last-axis sum of `compute_transition_beam_scores` is close (atol=1e-3)
    to the `sequences_scores` returned by `generate`.
    """
    checkpoint = "hf-internal-testing/tiny-random-bart"
    prompts = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    # These settings are passed to `from_pretrained` as extra keyword arguments;
    # `generate` below is called with `input_ids` only.
    generation_overrides = dict(
        max_length=10,
        num_beams=4,
        num_return_sequences=2,
        eos_token_id=None,
        return_dict_in_generate=True,
        output_scores=True,
        length_penalty=0.0,
    )

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint, **generation_overrides).to(torch_device)

    encoded = bart_tokenizer(prompts, return_tensors="pt", padding=True)
    generated = model.generate(input_ids=encoded.input_ids.to(torch_device))

    per_step_scores = model.compute_transition_beam_scores(
        generated.sequences, generated.scores, generated.beam_indices
    )
    summed_scores = per_step_scores.sum(-1)

    scores_match = torch.allclose(summed_scores, generated.sequences_scores, atol=1e-3)
    self.assertTrue(scores_match)
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
    """Summed beam transition scores must match `sequences_scores` when EOS is not overridden.

    Same check as the plain encoder-decoder beam-search test, except that no
    `eos_token_id` override is passed to `from_pretrained`.
    NOTE(review): presumably the checkpoint's own EOS token therefore stays
    active during generation - confirm against the checkpoint config.
    """
    checkpoint = "hf-internal-testing/tiny-random-bart"
    prompts = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    # These settings are passed to `from_pretrained` as extra keyword arguments;
    # `generate` below is called with `input_ids` only.
    generation_overrides = dict(
        max_length=10,
        num_beams=4,
        num_return_sequences=2,
        return_dict_in_generate=True,
        output_scores=True,
        length_penalty=0.0,
    )

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint, **generation_overrides).to(torch_device)

    encoded = bart_tokenizer(prompts, return_tensors="pt", padding=True)
    generated = model.generate(input_ids=encoded.input_ids.to(torch_device))

    per_step_scores = model.compute_transition_beam_scores(
        generated.sequences, generated.scores, generated.beam_indices
    )
    summed_scores = per_step_scores.sum(-1)

    scores_match = torch.allclose(summed_scores, generated.sequences_scores, atol=1e-3)
    self.assertTrue(scores_match)
|
||||
|
||||
def test_transition_scores_beam_search_decoder_only(self):
    """Summed beam transition scores must match `sequences_scores` (GPT-2, beam search).

    Loads a tiny random GPT-2 checkpoint with `num_beams=4`, `eos_token_id=None`
    and `length_penalty=0.0`, generates continuations for two short prompts, and
    asserts that the last-axis sum of `compute_transition_beam_scores` is close
    (atol=1e-3) to the `sequences_scores` returned by `generate`.
    """
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompts = [
        "Justin Timberlake",
        "Michael Phelps",
    ]

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    # The EOS token is reused as the padding token so the batch can be padded.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    # These settings are passed to `from_pretrained` as extra keyword arguments;
    # `generate` below is called with `input_ids` only.
    generation_overrides = dict(
        max_length=10,
        num_beams=4,
        num_return_sequences=2,
        pad_token_id=gpt2_tokenizer.eos_token_id,
        eos_token_id=None,
        return_dict_in_generate=True,
        output_scores=True,
        length_penalty=0.0,
    )
    model = GPT2LMHeadModel.from_pretrained(checkpoint, **generation_overrides).to(torch_device)

    encoded = gpt2_tokenizer(prompts, return_tensors="pt", padding=True)
    generated = model.generate(input_ids=encoded.input_ids.to(torch_device))

    per_step_scores = model.compute_transition_beam_scores(
        generated.sequences, generated.scores, generated.beam_indices
    )
    summed_scores = per_step_scores.sum(-1)

    scores_match = torch.allclose(summed_scores, generated.sequences_scores, atol=1e-3)
    self.assertTrue(scores_match)
|
||||
|
||||
def test_transition_scores_beam_sample_encoder_decoder(self):
    """Summed beam transition scores must match `sequences_scores` (BART, `do_sample=True`).

    Loads a tiny random BART checkpoint with `do_sample=True`, `num_beams=4`,
    `eos_token_id=None` and `length_penalty=0.0`, generates for two short
    articles, and asserts that the last-axis sum of
    `compute_transition_beam_scores` is close (atol=1e-3) to the
    `sequences_scores` returned by `generate`. No random seed is set here.
    """
    checkpoint = "hf-internal-testing/tiny-random-bart"
    prompts = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    # These settings are passed to `from_pretrained` as extra keyword arguments;
    # `generate` below is called with `input_ids` only.
    generation_overrides = dict(
        do_sample=True,
        max_length=10,
        num_beams=4,
        num_return_sequences=2,
        eos_token_id=None,
        return_dict_in_generate=True,
        output_scores=True,
        length_penalty=0.0,
    )

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint, **generation_overrides).to(torch_device)

    encoded = bart_tokenizer(prompts, return_tensors="pt", padding=True)
    generated = model.generate(input_ids=encoded.input_ids.to(torch_device))

    per_step_scores = model.compute_transition_beam_scores(
        generated.sequences, generated.scores, generated.beam_indices
    )
    summed_scores = per_step_scores.sum(-1)

    scores_match = torch.allclose(summed_scores, generated.sequences_scores, atol=1e-3)
    self.assertTrue(scores_match)
|
||||
|
||||
def test_transition_scores_group_beam_search_encoder_decoder(self):
    """Summed beam transition scores must match `sequences_scores` (BART, beam groups).

    Loads a tiny random BART checkpoint with `num_beams=2`, `num_beam_groups=2`,
    `eos_token_id=None` and `length_penalty=0.0`, generates for two short
    articles, and asserts that the last-axis sum of
    `compute_transition_beam_scores` is close (atol=1e-3) to the
    `sequences_scores` returned by `generate`.
    """
    checkpoint = "hf-internal-testing/tiny-random-bart"
    prompts = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    # These settings are passed to `from_pretrained` as extra keyword arguments;
    # `generate` below is called with `input_ids` only.
    generation_overrides = dict(
        max_length=10,
        num_beams=2,
        num_beam_groups=2,
        num_return_sequences=2,
        eos_token_id=None,
        return_dict_in_generate=True,
        output_scores=True,
        length_penalty=0.0,
    )

    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint, **generation_overrides).to(torch_device)

    encoded = bart_tokenizer(prompts, return_tensors="pt", padding=True)
    generated = model.generate(input_ids=encoded.input_ids.to(torch_device))

    per_step_scores = model.compute_transition_beam_scores(
        generated.sequences, generated.scores, generated.beam_indices
    )
    summed_scores = per_step_scores.sum(-1)

    scores_match = torch.allclose(summed_scores, generated.sequences_scores, atol=1e-3)
    self.assertTrue(scores_match)
|
||||
|
||||
Reference in New Issue
Block a user