[Flax] Add Beam Search (#12131)
* fix_torch_device_generate_test
* remove @
* push new logit processors
* add processors
* save first working version
* save intermediate
* finish
* make style
* make fix-copies
* finish
* Update tests/test_modeling_flax_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
802ffaff0d
commit
c3c39f7e84
@@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate(self):
    """Compare eager and ``jit``-compiled generation with ``num_beams = 2``.

    Sampling is switched off and ``max_length`` is taken from
    ``self._get_input_ids_and_config()``. For every class in
    ``self.all_generative_model_classes`` the test asserts that the last
    dimension of the generated sequences equals ``max_length`` and that the
    ``jit``-compiled call yields exactly the same sequences as the eager one.
    """
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.do_sample = False
    config.max_length = max_length
    config.num_beams = 2

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # Reference run without compilation.
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # Same call, this time through jit.
        compiled_sequences = jit(model.generate)(input_ids).sequences

        self.assertListEqual(eager_sequences.tolist(), compiled_sequences.tolist())
|
||||
|
||||
def test_sample_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
@@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
|
||||
config.temperature = 0.8
|
||||
config.top_k = 10
|
||||
config.top_p = 0.3
|
||||
config.min_length = 1
|
||||
config.forced_bos_token_id = 8
|
||||
config.forced_eos_token_id = 9
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate_logits_warper(self):
    """Compare eager and ``jit``-compiled generation with length/forced-token settings.

    The config is given ``min_length = 1``, ``forced_bos_token_id = 8`` and
    ``forced_eos_token_id = 9`` on top of the ``max_length`` returned by
    ``self._get_input_ids_and_config()``. For every class in
    ``self.all_generative_model_classes`` the test asserts that the last
    dimension of the generated sequences equals ``max_length`` and that the
    ``jit``-compiled call yields exactly the same sequences as the eager one.
    """
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.max_length = max_length
    config.min_length = 1
    config.forced_bos_token_id = 8
    config.forced_eos_token_id = 9

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # Reference run without compilation.
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # Same call, this time through jit.
        compiled_sequences = jit(model.generate)(input_ids).sequences

        self.assertListEqual(eager_sequences.tolist(), compiled_sequences.tolist())
|
||||
|
||||
def test_beam_search_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.max_length = max_length
|
||||
config.num_beams = 2
|
||||
config.min_length = 1
|
||||
config.forced_bos_token_id = 8
|
||||
config.forced_eos_token_id = 9
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate_attn_mask(self):
    """Compare eager and ``jit``-compiled generation with a modified attention mask.

    The attention mask returned by ``self._get_input_ids_and_config()`` has
    its ``[0, 0]`` entry set to 0 before generation, and the config uses
    ``num_beams = 2``. For every class in ``self.all_generative_model_classes``
    the test asserts that the last dimension of the generated sequences equals
    ``max_length`` and that the ``jit``-compiled call yields exactly the same
    sequences as the eager one.
    """
    config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

    # pad attention mask on the left
    # `jax.ops.index_update` was deprecated and later removed from JAX; the
    # `.at[...].set(...)` form is the supported functional equivalent and
    # likewise returns an updated copy. `asarray` keeps accepting any
    # array-like input, as `index_update` did.
    attention_mask = jax.numpy.asarray(attention_mask).at[0, 0].set(0)

    config.num_beams = 2
    config.max_length = max_length

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # Reference run without compilation.
        generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
        self.assertEqual(generation_outputs.shape[-1], max_length)

        # Same call, this time through jit.
        jit_generate = jit(model.generate)
        jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences

        self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
Reference in New Issue
Block a user