🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)
This commit is contained in:
@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_beam_search_warning_if_max_length_is_passed(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
batch_size = 1
|
||||
num_beams = 3
|
||||
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids = input_ids.expand(num_beams, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
|
||||
# pretend decoder_input_ids correspond to first encoder input id
|
||||
decoder_input_ids = input_ids[:, :1]
|
||||
|
||||
stopping_criteria_max_length = 18
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
generated_ids = bart_model.beam_search(
|
||||
decoder_input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
beam_scorer_no_max_len = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
generated_ids_no_max_len = bart_model.beam_search(
|
||||
decoder_input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer_no_max_len,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# BeamSearchScorer max_length should not influence "real" max_length
|
||||
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||
|
||||
def test_custom_stopping_criteria_overload_error(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
|
||||
Reference in New Issue
Block a user