[FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)

* [FLAX] Add dtype to embedding for bert/bart/opt/t5

* Fix all copies

* Add a test case
This commit is contained in:
Lianmin Zheng
2022-11-28 07:21:42 -08:00
committed by GitHub
parent 667ccea722
commit ac2f6674a3
12 changed files with 41 additions and 0 deletions

View File

@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
self.assertTrue(output_str == "Hello there!")
@slow
def test_small_generation_bfloat16(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
model.config.max_length = 8
model.config.num_beams = 1
model.config.do_sample = False
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
sequences = model.generate(input_ids).sequences
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
self.assertTrue(output_str == "Hello there!")
@slow
def test_summarization(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")