Flax Generate (#11777)
* fix_torch_device_generate_test * remove @ * add * indexing * correct a couple of tests * fix tests * add logits processor * finish top_k, top_p, temp * add docs * correct flax prng key default * improve generate * add generation docs * add docs * make style * revert model outputs change * make style * correct typo * fix tests * fix slow test * add raise * finish generation Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2df546918e
commit
996a315e76
@@ -19,16 +19,16 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import GPT2Config, is_flax_available, is_torch_available
|
||||
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -116,8 +116,25 @@ class FlaxGPT2ModelTester:
|
||||
model = model_class_name(config)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
|
||||
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
|
||||
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
|
||||
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids)
|
||||
|
||||
@@ -134,10 +151,22 @@ class FlaxGPT2ModelTester:
|
||||
)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
|
||||
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask_cache,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
|
||||
input_ids[:, -1:],
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
attention_mask=attention_mask_cache,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids, attention_mask=attention_mask)
|
||||
@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester:
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_generation(self, config, input_ids):
|
||||
prompt_length = 3
|
||||
model = FlaxGPT2LMHeadModel(config)
|
||||
max_length = 10
|
||||
batch_size = 1
|
||||
|
||||
prompt_ids = input_ids[:1, :prompt_length]
|
||||
|
||||
# put all generation logic into one function
|
||||
def generate(prompt_ids):
|
||||
def first_pass(prompt_ids):
|
||||
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
|
||||
next_token = jnp.argmax(logits[:, -1:], axis=-1)
|
||||
return next_token, cache
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
cur_len, _, _, _ = state
|
||||
return ~(cur_len == max_length - 1)
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
cur_len, sequences, current_token, cache = state
|
||||
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
|
||||
|
||||
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
|
||||
next_token = jnp.argmax(next_logits, axis=-1)
|
||||
|
||||
return cur_len + 1, next_sequences, next_token, next_cache
|
||||
|
||||
# init tensor to be filled with generation result
|
||||
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
|
||||
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
|
||||
|
||||
# init past key values for cache
|
||||
past_key_values = model.init_cache(batch_size, max_length)
|
||||
|
||||
# first pass with long prompt
|
||||
next_token, cache = first_pass(prompt_ids)
|
||||
|
||||
# prepare state for generation loop
|
||||
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
|
||||
|
||||
# fast generation
|
||||
_, output_sequences, final_token, _ = lax.while_loop(
|
||||
greedy_search_cond_fn, greedy_search_body_fn, init_state
|
||||
)
|
||||
|
||||
# append last token
|
||||
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
|
||||
|
||||
return output_sequences
|
||||
|
||||
jit_generate = jax.jit(generate)
|
||||
output_sequences = jit_generate(prompt_ids)
|
||||
self.parent.assertEqual(output_sequences.shape, (1, max_length))
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
||||
all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxGPT2ModelTester(self)
|
||||
@@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
def test_use_cache_generation(self):
|
||||
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_generation(config, input_ids)
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
||||
|
||||
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.do_sample = False
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
jit_generate = jax.jit(model.generate)
|
||||
|
||||
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
|
||||
|
||||
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
||||
|
||||
expected_string = [
|
||||
"Hello this is a long string of words. I'm going to try to explain what I mean.",
|
||||
"Hey, I'm not sure if I'm going to be able to do",
|
||||
]
|
||||
|
||||
self.assertListEqual(output_string, expected_string)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
|
||||
Reference in New Issue
Block a user