Flax T5 (#12150)
* copy pytorch-t5 * init * boom boom * forward pass same * make generation work * add more tests * make test work * finish normal tests * make fix-copies * finish quality * correct slow example * correct slow test * version table * upload models * Update tests/test_modeling_flax_t5.py * correct incorrectly deleted line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
|
||||
}
|
||||
|
||||
|
||||
class FlaxBartModelTester(unittest.TestCase):
|
||||
class FlaxBartModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
|
||||
513
tests/test_modeling_flax_t5.py
Normal file
513
tests/test_modeling_flax_t5.py
Normal file
File diff suppressed because one or more lines are too long
@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
def tokenizer(self):
|
||||
return T5Tokenizer.from_pretrained("t5-base")
|
||||
|
||||
@slow
|
||||
def test_small_generation(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||
model.config.max_length = 8
|
||||
model.config.num_beams = 1
|
||||
model.config.do_sample = False
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
|
||||
|
||||
sequences = model.generate(input_ids)
|
||||
|
||||
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||
self.assertTrue(output_str == "Hello there!")
|
||||
|
||||
@slow
|
||||
def test_small_integration_test(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user