Refactoring the generate() function (#6949)

* first draft

* show design proposition for new generate method

* up

* make more readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save intermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests into separate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of Sylvain's and Sam's comments

* fix some typos

* make fix copies

* apply Lysandre's and Sylvain's comments

* final corrections on examples

* small fix for reformer
This commit is contained in:
Patrick von Platen
2020-11-03 16:04:22 +01:00
committed by GitHub
parent b63beb743c
commit a1bbcf3f6c
38 changed files with 3022 additions and 1191 deletions

View File

@@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = ["Sam"]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
@@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
@@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
generated_ids = self.model.generate(**model_inputs)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
@@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
def test_90_generation_from_short_input(self):
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
generated_utterances = self.model.generate(**model_inputs)
# generated_txt = self.tokenizer.decode(generated_utterances[0])
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
generated_utterances = self.model.generate(**model_inputs)
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
assert clean_txt in (
"have you ever been to a sam club? it's a great club in the south.",