Add support for beam search's num_return_sequencs flag in flax (#23082)
* add code for numReturnSeq * add flax support for num return sequences * Make Fix up for changes * add test for num return sequences * lint
This commit is contained in:
@@ -158,6 +158,19 @@ class FlaxGenerationTesterMixin:
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate_num_return_sequences(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
config.num_beams = 2
|
||||
config.num_return_sequences = 2
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
|
||||
|
||||
def test_sample_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
|
||||
Reference in New Issue
Block a user