* copy pytorch-t5

* init

* boom boom

* forward pass same

* make generation work

* add more tests

* make test work

* finish normal tests

* make fix-copies

* finish quality

* correct slow example

* correct slow test

* version table

* upload models

* Update tests/test_modeling_flax_t5.py

* correct incorrectly deleted line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Vasudev Gupta
2021-06-23 17:43:32 +05:30
committed by GitHub
parent 7d4cfa3b47
commit e98233dde1
13 changed files with 2180 additions and 7 deletions

View File

@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
}
class FlaxBartModelTester(unittest.TestCase):
class FlaxBartModelTester:
def __init__(
self,
parent,

File diff suppressed because one or more lines are too long

View File

@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
def tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
@slow
def test_small_generation(self):
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
model.config.max_length = 8
model.config.num_beams = 1
model.config.do_sample = False
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
sequences = model.generate(input_ids)
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
self.assertTrue(output_str == "Hello there!")
@slow
def test_small_integration_test(self):
"""