Fix GPT2DoubleHeadsModel to work with model.generate() (#6601)
* Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used and for GPT2LMHeadModel too * Update tests to check token_type_ids usage in GPT2 models
This commit is contained in:
@@ -469,12 +469,26 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
token_type_ids = torch.cat(
|
||||
[
|
||||
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
|
||||
input_ids.new_full((input_ids.shape[0], 1), 500),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
)
|
||||
|
||||
outputs_tt = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
|
||||
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||
|
||||
@@ -483,6 +497,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||
|
||||
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
|
||||
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||
|
||||
@@ -491,6 +506,67 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"Today, I'm going to be doing a lot of research on this. I",
|
||||
]
|
||||
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
|
||||
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||
|
||||
@slow
|
||||
def test_batch_generation_2heads(self):
|
||||
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# This tokenizer has no pad token, so we have to set it in some way
|
||||
# Define PAD Token = EOS Token = 50256
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
# use different length sentences to test batching
|
||||
sentences = [
|
||||
"Hello, my dog is a little",
|
||||
"Today, I",
|
||||
]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"].to(torch_device)
|
||||
token_type_ids = torch.cat(
|
||||
[
|
||||
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
|
||||
input_ids.new_full((input_ids.shape[0], 1), 500),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
)
|
||||
|
||||
outputs_tt = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
|
||||
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||
|
||||
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
|
||||
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
|
||||
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||
|
||||
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
|
||||
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||
|
||||
expected_output_sentence = [
|
||||
"Hello, my dog is a little bit of a mess. I'm not sure if he's going",
|
||||
"Today, I'm going to be doing a lot of research on this. I",
|
||||
]
|
||||
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
|
||||
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||
|
||||
@slow
|
||||
@@ -540,11 +616,23 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
|
||||
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
|
||||
input_ids = tokenized.input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids, do_sample=True)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
token_type_ids = tokenized.token_type_ids.to(torch_device)
|
||||
output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
|
||||
output_seq_tt = model.generate(
|
||||
input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
|
||||
)
|
||||
output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
|
||||
output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STR = (
|
||||
"Today is a nice day and if you don't know anything about the state of play during your holiday"
|
||||
)
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||
self.assertTrue(
|
||||
all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
|
||||
) # token_type_ids should change output
|
||||
|
||||
Reference in New Issue
Block a user