[Generation] Fix Transition probs (#17311)

* [Draft] fix transition probs

* up

* up

* up

* make it work

* fix

* finish

* update
This commit is contained in:
Patrick von Platen
2022-05-19 22:17:02 +02:00
committed by GitHub
parent e8714c0307
commit 518bd02c9b
4 changed files with 199 additions and 77 deletions

View File

@@ -126,7 +126,11 @@ class BeamSearchTester:
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
beam_indices = tuple(tuple(b) for b in beam_indices)
beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
@@ -136,7 +140,7 @@ class BeamSearchTester:
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
@@ -161,10 +165,15 @@ class BeamSearchTester:
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
expected_beam_indices = list(range(10))
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
self.parent.assertListEqual(
expected_beam_indices + [next_indices[batch_idx, 1].item()],
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
@@ -188,6 +197,8 @@ class BeamSearchTester:
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
beam_indices = tuple(tuple(b) for b in beam_indices)
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
@@ -196,6 +207,7 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
@@ -225,6 +237,7 @@ class BeamSearchTester:
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
beam_indices=beam_indices,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
def check_constrained_beam_scorer_finalize(

View File

@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search's`
# transition scores are computed correctly for varying `num_return_sequences`,
# `num_beams` and `batch_size > 1`
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
torch_device
)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
result = model.generate(
input_ids,
max_length=10,
return_dict_in_generate=True,
output_scores=True,
forced_eos_token_id=model.config.eos_token_id,
num_beams=4,
do_sample=False,
num_return_sequences=3,
length_penalty=0.0,
)
transition_scores = model.compute_transition_beam_scores(
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)
sum_transition_scores = torch.sum(transition_scores, dim=1)
self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
def test_log_scores_sample_decoder_only(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=15,
return_dict_in_generate=True,
do_sample=False,
output_scores=True,
)
# decoder-only starts generating from `input_ids`
begin_generation = inputs.input_ids.shape[-1]
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
def test_log_scores_sample_encoder_decoder(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
result = model.generate(
**inputs,
max_length=3,
return_dict_in_generate=True,
do_sample=False,
num_beams=1,
output_scores=True,
)
# encoder-decoder has one decoder_start_token_id by default
begin_generation = 1
gen_sequences = result.sequences[:, begin_generation:]
probs = torch.stack(result.scores, dim=1).softmax(-1)
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
@slow
def test_beam_search_example_integration(self):
# exactly the example provided in the docstrings of beam search, which previously
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search_mixed(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
@slow
def test_constrained_beam_search_mixed_mixin(self):
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]