[Generation] Fix Transition probs (#17311)
* [Draft] fix transition probs * up * up * up * make it work * fix * finish * update
This commit is contained in:
committed by
GitHub
parent
e8714c0307
commit
518bd02c9b
@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
def test_transition_scores_early_stopping(self):
    """Check that summed beam-search transition scores match `sequences_scores`.

    Runs beam search on a batch of two identical prompts with `num_beams=4`
    and `num_return_sequences=3`, recomputes the per-step transition scores
    with `compute_transition_beam_scores`, sums them along dim 1 and compares
    the result with the `sequences_scores` returned by `generate`.
    """
    # This is an aggressive test that makes sure that `beam_search`'s
    # transition scores are computed correctly for varying `num_return_sequences`,
    # `num_beams` and `batch_size > 1`
    # 2 x input_ids for "question: How are you? \n context: I had a long day, "
    input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
        torch_device
    )

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)

    # NOTE(review): `length_penalty=0.0` presumably disables length normalisation so that
    # `sequences_scores` equals the plain sum of transition scores -- confirm against the
    # `generate` documentation.
    result = model.generate(
        input_ids,
        max_length=10,
        return_dict_in_generate=True,
        output_scores=True,
        forced_eos_token_id=model.config.eos_token_id,
        num_beams=4,
        do_sample=False,
        num_return_sequences=3,
        length_penalty=0.0,
    )

    transition_scores = model.compute_transition_beam_scores(
        sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
    )

    sum_transition_scores = torch.sum(transition_scores, dim=1)

    # Compare with a tolerance instead of exact list equality: the two values are
    # floating-point sums accumulated along different code paths, so bit-for-bit
    # equality is not guaranteed across devices or library versions.
    self.assertTrue(torch.allclose(sum_transition_scores, result.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_log_scores_sample_decoder_only(self):
    """Check per-token probabilities of the generated tokens for a tiny GPT-2 model.

    Generates from two left-padded prompts, turns the returned step scores
    into probabilities with a softmax, picks out the probability of each
    generated token and compares against hard-coded expected values.
    """
    prompts = ["I need input_ids to generate", "Short and"]

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Pad on the left and reuse EOS as the pad token.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = model.to(torch_device)

    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)

    result = model.generate(
        **inputs,
        max_length=15,
        return_dict_in_generate=True,
        do_sample=False,
        output_scores=True,
    )

    # decoder-only starts generating from `input_ids`
    prompt_length = inputs.input_ids.shape[-1]
    new_tokens = result.sequences[:, prompt_length:]

    # Stack the per-step scores along dim 1 and normalise over the last dim.
    stacked_scores = torch.stack(result.scores, dim=1)
    step_probs = torch.softmax(stacked_scores, dim=-1)

    # Select, for every step, the probability assigned to the token that was generated.
    token_index = new_tokens.unsqueeze(-1)
    gen_probs = step_probs.gather(2, token_index).squeeze(-1)

    expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
    matches = torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3)
    self.assertTrue(matches)
|
||||
|
||||
def test_log_scores_sample_encoder_decoder(self):
    """Check per-token probabilities of the generated tokens for a tiny BART model.

    Generates from two padded prompts with `num_beams=1`, turns the returned
    step scores into probabilities with a softmax, picks out the probability
    of each generated token and compares against hard-coded expected values.
    """
    prompts = ["I need input_ids to generate", "Short and"]

    tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
    model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
    model = model.to(torch_device)

    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)

    result = model.generate(
        **inputs,
        max_length=3,
        return_dict_in_generate=True,
        do_sample=False,
        num_beams=1,
        output_scores=True,
    )

    # encoder-decoder has one decoder_start_token_id by default
    first_generated = 1
    new_tokens = result.sequences[:, first_generated:]

    # Stack the per-step scores along dim 1 and normalise over the last dim.
    stacked_scores = torch.stack(result.scores, dim=1)
    step_probs = torch.softmax(stacked_scores, dim=-1)

    # Select, for every step, the probability assigned to the token that was generated.
    token_index = new_tokens.unsqueeze(-1)
    gen_probs = step_probs.gather(2, token_index).squeeze(-1)

    expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
    matches = torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3)
    self.assertTrue(matches)
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# exactly the example provided in the docstrings of beam search, which previously
|
||||
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
flexible_phrases = tokenizer(
|
||||
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_word = "scared"
|
||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||
|
||||
Reference in New Issue
Block a user