From e15f0d73db464c8a7abaeeb2a78b0df142b9a0ec Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 18 Jan 2023 14:24:53 +0000 Subject: [PATCH] OPT: Fix batched generation with FLAX (#21150) * Fix Flax OPT numerical masking * re-enable test * add fix to bart and reintroduce copied from in opt --- .../models/bart/modeling_flax_bart.py | 2 +- .../blenderbot/modeling_flax_blenderbot.py | 2 +- .../modeling_flax_blenderbot_small.py | 2 +- .../models/marian/modeling_flax_marian.py | 2 +- .../models/mbart/modeling_flax_mbart.py | 2 +- .../models/opt/modeling_flax_opt.py | 2 +- .../models/pegasus/modeling_flax_pegasus.py | 2 +- tests/models/opt/test_modeling_flax_opt.py | 58 +++++++++---------- 8 files changed, 34 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 90ddfa57cb..a17e4185fd 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 1b3b57b95b..8abacc090b 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index e5a0352d24..fc7c938512 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index da2e4a1fe5..74f6b1fe8d 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index afc67be57b..a99b58fb5b 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py index 1237e3b25f..83af4da8b0 100644 --- a/src/transformers/models/opt/modeling_flax_opt.py +++ b/src/transformers/models/opt/modeling_flax_opt.py @@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index c4ecd25b6e..ab246b5ef8 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module): attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + jnp.full(attention_mask.shape, -1e9).astype(self.dtype), ) else: attention_bias = None diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py index 402e556cef..04d1f75db1 100644 --- a/tests/models/opt/test_modeling_flax_opt.py +++ b/tests/models/opt/test_modeling_flax_opt.py @@ -364,43 +364,39 @@ class FlaxOPTGenerationTest(unittest.TestCase): self.assertIsNotNone(output_string, EXPECTED_OUTPUTS) - # TODO fix in the following PR - # def test_batch_generation(self): - # model_id = "facebook/opt-350m" + def test_batch_generation(self): + model_id = "facebook/opt-350m" - # tokenizer = GPT2Tokenizer.from_pretrained(model_id) - # model = FlaxOPTForCausalLM.from_pretrained(model_id) + tokenizer = GPT2Tokenizer.from_pretrained(model_id) + model = FlaxOPTForCausalLM.from_pretrained(model_id) - # tokenizer.padding_side = "left" + tokenizer.padding_side = "left" - # # use different length sentences to test batching - # sentences = [ - # "Hello, my dog is a little", - # "Today, I", - # ] + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] - # inputs = tokenizer(sentences, return_tensors="jax", padding=True) - # input_ids = inputs["input_ids"] + inputs = tokenizer(sentences, return_tensors="jax", padding=True) + input_ids = inputs["input_ids"] - # outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False) + outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False) - # inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids - # output_non_padded = model.generate(input_ids=inputs_non_padded) + inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids + output_non_padded = model.generate(input_ids=inputs_non_padded) - # num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum() - # inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids - # output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum() + inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) - # batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) - # non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True) - # padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True) + batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True) - # expected_output_sentence = [ - # "Hello, my dog is a little bit of a dork.\nI'm a little bit", - # "Today, I" - # # TODO fix this test in next PR - # # "Today, I was in the middle of a conversation with a friend about the", - # ] - # self.assertListEqual(expected_output_sentence, batch_out_sentence) - # # TODO outputs will be similar, fix in next PR - # self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence]) + expected_output_sentence = [ + "Hello, my dog is a little bit of a dork.\nI'm a little bit", + "Today, I was in the middle of a conversation with a friend about the", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])