OPT - fix docstring and improve tests slighly (#17228)
* correct some stuff * fix doc tests * make style
This commit is contained in:
committed by
Lysandre Debut
parent
219abba24c
commit
f79af76fc1
@@ -21,7 +21,7 @@ import unittest
|
||||
|
||||
import timeout_decorator # noqa
|
||||
|
||||
from transformers import OPTConfig, is_torch_available, pipeline
|
||||
from transformers import OPTConfig, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
@@ -330,33 +330,61 @@ class OPTEmbeddingsTest(unittest.TestCase):
|
||||
assert torch.allclose(logits, logits_meta, atol=1e-4)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@slow
|
||||
class OPTGenerationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.all_model_path = ["facebook/opt-125m", "facebook/opt-350m"]
|
||||
|
||||
def test_generation(self):
|
||||
prompts = [
|
||||
@property
|
||||
def prompts(self):
|
||||
return [
|
||||
"Today is a beautiful day and I want to",
|
||||
"In the city of",
|
||||
"Paris is the capital of France and",
|
||||
"Computers and mobile phones have taken",
|
||||
]
|
||||
NEXT_TOKENS = [3392, 764, 5, 81]
|
||||
GEN_OUTPUT = []
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
|
||||
for model in self.all_model_path:
|
||||
model = OPTForCausalLM.from_pretrained(self.path_model)
|
||||
model = model.eval()
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
def test_generation_pre_attn_layer_norm(self):
|
||||
model_id = "facebook/opt-125m"
|
||||
|
||||
gen = pipeline("text-generation", model=model, tokenizer=tokenizer, return_tensors=True)
|
||||
EXPECTED_OUTPUTS = [
|
||||
"Today is a beautiful day and I want to thank",
|
||||
"In the city of Rome Canaver Canaver Canaver Canaver",
|
||||
"Paris is the capital of France and Parisdylib",
|
||||
"Computers and mobile phones have taken precedence over",
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
len_input_sentence = len(tokenizer.tokenize(prompt))
|
||||
predicted_next_token = gen(prompt)[0]["generated_token_ids"][len_input_sentence]
|
||||
GEN_OUTPUT.append(predicted_next_token)
|
||||
self.assertListEqual(GEN_OUTPUT, NEXT_TOKENS)
|
||||
predicted_outputs = []
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
|
||||
model = OPTForCausalLM.from_pretrained(model_id)
|
||||
|
||||
for prompt in self.prompts:
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
generated_ids = model.generate(input_ids, max_length=10)
|
||||
|
||||
generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
predicted_outputs += generated_string
|
||||
|
||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
def test_generation_post_attn_layer_norm(self):
|
||||
model_id = "facebook/opt-350m"
|
||||
|
||||
EXPECTED_OUTPUTS = [
|
||||
"Today is a beautiful day and I want to share",
|
||||
"In the city of San Francisco, the city",
|
||||
"Paris is the capital of France and the capital",
|
||||
"Computers and mobile phones have taken over the",
|
||||
]
|
||||
|
||||
predicted_outputs = []
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
|
||||
model = OPTForCausalLM.from_pretrained(model_id)
|
||||
|
||||
for prompt in self.prompts:
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
generated_ids = model.generate(input_ids, max_length=10)
|
||||
|
||||
generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
predicted_outputs += generated_string
|
||||
|
||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
Reference in New Issue
Block a user