Flax Generate (#11777)

* fix_torch_device_generate_test

* remove @

* add

* indexing

* correct a couple of tests

* fix tests

* add logits processor

* finish top_k, top_p, temp

* add docs

* correct flax prng key default

* improve generate

* add generation docs

* add docs

* make style

* revert model outputs change

* make style

* correct typo

* fix tests

* fix slow test

* add raise

* finish generation

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-05-27 00:18:17 +01:00
committed by GitHub
parent 2df546918e
commit 996a315e76
11 changed files with 1080 additions and 96 deletions

View File

@@ -19,16 +19,16 @@ import unittest
import numpy as np
import transformers
from transformers import GPT2Config, is_flax_available, is_torch_available
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
from jax import lax
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
@@ -116,8 +116,25 @@ class FlaxGPT2ModelTester:
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
@@ -134,10 +151,22 @@ class FlaxGPT2ModelTester:
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester:
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_generation(self, config, input_ids):
prompt_length = 3
model = FlaxGPT2LMHeadModel(config)
max_length = 10
batch_size = 1
prompt_ids = input_ids[:1, :prompt_length]
# put all generation logic into one function
def generate(prompt_ids):
def first_pass(prompt_ids):
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
next_token = jnp.argmax(logits[:, -1:], axis=-1)
return next_token, cache
def greedy_search_cond_fn(state):
cur_len, _, _, _ = state
return ~(cur_len == max_length - 1)
def greedy_search_body_fn(state):
cur_len, sequences, current_token, cache = state
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
next_token = jnp.argmax(next_logits, axis=-1)
return cur_len + 1, next_sequences, next_token, next_cache
# init tensor to be filled with generation result
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
# init past key values for cache
past_key_values = model.init_cache(batch_size, max_length)
# first pass with long prompt
next_token, cache = first_pass(prompt_ids)
# prepare state for generation loop
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
# fast generation
_, output_sequences, final_token, _ = lax.while_loop(
greedy_search_cond_fn, greedy_search_body_fn, init_state
)
# append last token
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
return output_sequences
jit_generate = jax.jit(generate)
output_sequences = jit_generate(prompt_ids)
self.parent.assertEqual(output_sequences.shape, (1, max_length))
@require_flax
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxGPT2ModelTester(self)
@@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
model_class_name, config, input_ids, attention_mask
)
def test_use_cache_generation(self):
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_generation(config, input_ids)
@slow
def test_batch_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
model.do_sample = False
model.config.pad_token_id = model.config.eos_token_id
jit_generate = jax.jit(model.generate)
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
expected_string = [
"Hello this is a long string of words. I'm going to try to explain what I mean.",
"Hey, I'm not sure if I'm going to be able to do",
]
self.assertListEqual(output_string, expected_string)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently