Generate: contrastive search with full optional outputs (#19963)
* Use beam search functionality; Add extra outputs and test * Add full tests for contrastive search * Add error message on unconventional cache format
This commit is contained in:
@@ -490,3 +490,34 @@ class OPTGenerationTest(unittest.TestCase):
|
||||
self.assertFalse(
|
||||
torch.isnan(outputs.logits[0]).any().item()
|
||||
) # the first logits could contain NaNs if it fails
|
||||
|
||||
@slow
|
||||
def test_contrastive_search_opt(self):
|
||||
article = (
|
||||
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
|
||||
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
|
||||
"there?"
|
||||
)
|
||||
|
||||
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
|
||||
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
|
||||
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
|
||||
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
|
||||
"am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
|
||||
"you lived there?\nStatue: A hundred years.\nHuman: And you’re from what country?\nStatue: The United "
|
||||
"States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
|
||||
"country.\nHuman: What tyranny?\nStatue: They didn’t let me speak my mind.\nHuman: What was your "
|
||||
"country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
|
||||
"were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
|
||||
"Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
|
||||
"every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
|
||||
"France.\nHuman: And your parents were French?\nStatue"
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user