committed by
GitHub
parent
48c22691e3
commit
54192058f3
@@ -366,6 +366,47 @@ class OPTGenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
def test_batch_generation(self):
|
||||||
|
model_id = "facebook/opt-350m"
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
|
||||||
|
model = OPTForCausalLM.from_pretrained(model_id)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
# use different length sentences to test batching
|
||||||
|
sentences = [
|
||||||
|
"Hello, my dog is a little",
|
||||||
|
"Today, I",
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||||
|
input_ids = inputs["input_ids"].to(torch_device)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||||
|
|
||||||
|
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
|
||||||
|
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||||
|
|
||||||
|
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||||
|
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
expected_output_sentence = [
|
||||||
|
"Hello, my dog is a little bit of a dork.\nI'm a little bit",
|
||||||
|
"Today, I was in the middle of a conversation with a friend about the",
|
||||||
|
]
|
||||||
|
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||||
|
self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
|
||||||
|
|
||||||
def test_generation_post_attn_layer_norm(self):
|
def test_generation_post_attn_layer_norm(self):
|
||||||
model_id = "facebook/opt-350m"
|
model_id = "facebook/opt-350m"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user