[FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)
* [FLAX] Add dtype to embedding for bert/bart/opt/t5 * Fix all copies * Add a test case
This commit is contained in:
@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
|
||||
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||
self.assertTrue(output_str == "Hello there!")
|
||||
|
||||
@slow
|
||||
def test_small_generation_bfloat16(self):
|
||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
|
||||
model.config.max_length = 8
|
||||
model.config.num_beams = 1
|
||||
model.config.do_sample = False
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
|
||||
|
||||
sequences = model.generate(input_ids).sequences
|
||||
|
||||
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||
self.assertTrue(output_str == "Hello there!")
|
||||
|
||||
@slow
|
||||
def test_summarization(self):
|
||||
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
|
||||
|
||||
Reference in New Issue
Block a user