[Generation] Fix Transition probs (#17311)
* [Draft] fix transition probs * up * up * up * make it work * fix * finish * update
This commit is contained in:
committed by
GitHub
parent
e8714c0307
commit
518bd02c9b
@@ -126,7 +126,11 @@ class BeamSearchTester:
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||
beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||
)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(beam_scorer.is_done)
|
||||
|
||||
@@ -136,7 +140,7 @@ class BeamSearchTester:
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
@@ -161,10 +165,15 @@ class BeamSearchTester:
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
expected_beam_indices = list(range(10))
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
expected_beam_indices + [next_indices[batch_idx, 1].item()],
|
||||
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
||||
)
|
||||
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
@@ -188,6 +197,8 @@ class BeamSearchTester:
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||
sequence_output = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
@@ -196,6 +207,7 @@ class BeamSearchTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
@@ -225,6 +237,7 @@ class BeamSearchTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
|
||||
def check_constrained_beam_scorer_finalize(
|
||||
|
||||
@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
def test_transition_scores_early_stopping(self):
    """Check that summed beam-search transition scores match `sequences_scores`.

    Runs `generate` on a batch of two identical T5 inputs with `num_beams=4`
    and `num_return_sequences=3`, recomputes the per-step transition scores via
    `compute_transition_beam_scores`, sums them over the step dimension and
    compares the sums with the `sequences_scores` returned by `generate`.
    """
    # This is an aggressive test that makes sure that `beam_search`'s
    # transition scores are computed correctly for varying `num_return_sequences`,
    # `num_beams` and `batch_size > 1`
    # 2 x input_ids for "question: How are you? \n context: I had a long day, "
    input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
        torch_device
    )

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)

    # NOTE(review): `length_penalty=0.0` presumably disables length normalisation so
    # that `sequences_scores` are plain sums of transition scores -- confirm against
    # the beam scorer implementation.
    result = model.generate(
        input_ids,
        max_length=10,
        return_dict_in_generate=True,
        output_scores=True,
        forced_eos_token_id=model.config.eos_token_id,
        num_beams=4,
        do_sample=False,
        num_return_sequences=3,
        length_penalty=0.0,
    )

    transition_scores = model.compute_transition_beam_scores(
        sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
    )

    sum_transition_scores = torch.sum(transition_scores, dim=1)

    # Both sides are floating-point sums; compare with a tolerance (as the other
    # score checks in this file do) instead of requiring bit-exact list equality,
    # which is sensitive to accumulation order and device.
    self.assertTrue(torch.allclose(sum_transition_scores, result.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_log_scores_sample_decoder_only(self):
    """Verify per-token probabilities derived from `scores` for a tiny GPT-2.

    Generates with sampling disabled from two left-padded prompts, turns the
    returned step scores into probabilities with a softmax, picks out the
    probability of each generated token and compares it with stored
    reference values.
    """
    prompts = ["I need input_ids to generate", "Short and"]
    checkpoint = "hf-internal-testing/tiny-random-gpt2"

    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    # Pad on the left and reuse EOS as the padding token.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    model = GPT2LMHeadModel.from_pretrained(checkpoint).to(torch_device)

    encoded = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)

    generation_kwargs = {
        "max_length": 15,
        "return_dict_in_generate": True,
        "do_sample": False,
        "output_scores": True,
    }
    output = model.generate(**encoded, **generation_kwargs)

    # decoder-only starts generating from `input_ids`, so skip the prompt columns
    prompt_length = encoded.input_ids.shape[-1]
    new_tokens = output.sequences[:, prompt_length:]

    # Stack the per-step scores along dim 1 and normalise over the last dim.
    step_probs = torch.stack(output.scores, dim=1).softmax(-1)

    # Select, for every step, the probability assigned to the generated token.
    token_probs = step_probs.gather(2, new_tokens.unsqueeze(-1)).squeeze(-1)
    reference_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])

    self.assertTrue(torch.allclose(token_probs.cpu(), reference_probs, atol=1e-3))
|
||||
|
||||
def test_log_scores_sample_encoder_decoder(self):
    """Verify per-token probabilities derived from `scores` for a tiny BART.

    Generates with `do_sample=False` and `num_beams=1` from two padded
    prompts, turns the returned step scores into probabilities with a
    softmax, picks out the probability of each generated token and compares
    it with stored reference values.
    """
    prompts = ["I need input_ids to generate", "Short and"]
    checkpoint = "hf-internal-testing/tiny-random-bart"

    tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

    encoded = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)

    generation_kwargs = {
        "max_length": 3,
        "return_dict_in_generate": True,
        "do_sample": False,
        "num_beams": 1,
        "output_scores": True,
    }
    output = model.generate(**encoded, **generation_kwargs)

    # encoder-decoder has one decoder_start_token_id by default, so drop column 0
    first_generated = 1
    new_tokens = output.sequences[:, first_generated:]

    # Stack the per-step scores along dim 1 and normalise over the last dim.
    step_probs = torch.stack(output.scores, dim=1).softmax(-1)

    # Select, for every step, the probability assigned to the generated token.
    token_probs = step_probs.gather(2, new_tokens.unsqueeze(-1)).squeeze(-1)
    reference_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])

    self.assertTrue(torch.allclose(token_probs.cpu(), reference_probs, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# exactly the example provided in the docstrings of beam search, which previously
|
||||
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
flexible_phrases = tokenizer(
|
||||
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_word = "scared"
|
||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||
|
||||
Reference in New Issue
Block a user