[Flax] Add Beam Search (#12131)

* fix_torch_device_generate_test

* remove @

* push new logit processors

* add processors

* save first working version

* save intermediate

* finish

* make style

* make fix-copies

* finish

* Update tests/test_modeling_flax_bart.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Patrick von Platen
2021-06-16 09:43:54 +01:00
committed by GitHub
parent 802ffaff0d
commit c3c39f7e84
8 changed files with 840 additions and 44 deletions

View File

@@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.num_beams = 2
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
@@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
config.temperature = 0.8
config.top_k = 10
config.top_p = 0.3
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.num_beams = 2
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
@@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.num_beams = 2
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())