Refactoring the generate() function (#6949)
* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
This commit is contained in:
committed by
GitHub
parent
b63beb743c
commit
a1bbcf3f6c
@@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
|
||||
src_text = ["Sam"]
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
|
||||
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
|
||||
|
||||
@@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
|
||||
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
@@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
||||
generated_ids = self.model.generate(**model_inputs)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
@@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_90_generation_from_short_input(self):
|
||||
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
# generated_txt = self.tokenizer.decode(generated_utterances[0])
|
||||
|
||||
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
|
||||
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
|
||||
assert clean_txt in (
|
||||
"have you ever been to a sam club? it's a great club in the south.",
|
||||
|
||||
Reference in New Issue
Block a user