Add support for beam search's num_return_sequences flag in flax (#23082)
* add code for numReturnSeq
* add flax support for num return sequences
* Make Fix up for changes
* add test for num return sequences
* lint
This commit is contained in:
@@ -463,6 +463,7 @@ class FlaxGenerationMixin:
|
|||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
trace=trace,
|
trace=trace,
|
||||||
params=params,
|
params=params,
|
||||||
|
num_return_sequences=generation_config.num_return_sequences,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -749,6 +750,7 @@ class FlaxGenerationMixin:
|
|||||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
|
num_return_sequences: Optional[int] = None,
|
||||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -793,6 +795,9 @@ class FlaxGenerationMixin:
|
|||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
|
||||||
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
|
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
|
||||||
|
num_return_sequences = (
|
||||||
|
num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
|
||||||
|
)
|
||||||
|
|
||||||
batch_size, num_beams, cur_len = input_ids.shape
|
batch_size, num_beams, cur_len = input_ids.shape
|
||||||
|
|
||||||
@@ -996,8 +1001,8 @@ class FlaxGenerationMixin:
|
|||||||
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
|
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
|
||||||
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
|
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
|
||||||
|
|
||||||
# take best beam for each batch
|
# Take best beams for each batch (the score is sorted in descending order)
|
||||||
sequences = sequences[:, 0]
|
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
|
||||||
scores = scores[:, 0]
|
scores = flatten_beam_dim(scores[:, :num_return_sequences])
|
||||||
|
|
||||||
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
||||||
|
|||||||
@@ -158,6 +158,19 @@ class FlaxGenerationTesterMixin:
|
|||||||
|
|
||||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_beam_search_generate_num_return_sequences(self):
|
||||||
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
|
config.do_sample = False
|
||||||
|
config.max_length = max_length
|
||||||
|
config.num_beams = 2
|
||||||
|
config.num_return_sequences = 2
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
|
||||||
|
|
||||||
def test_sample_generate_logits_warper(self):
|
def test_sample_generate_logits_warper(self):
|
||||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
config.do_sample = True
|
config.do_sample = True
|
||||||
|
|||||||
Reference in New Issue
Block a user