Fix GPT2DoubleHeadsModel to work with model.generate() (#6601)
* Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used and for GPT2LMHeadModel too * Update tests to check token_type_ids usage in GPT2 models
This commit is contained in:
@@ -144,6 +144,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
input_ids = input_ids.index_select(0, expanded_return_idx)
|
input_ids = input_ids.index_select(0, expanded_return_idx)
|
||||||
|
|
||||||
|
if "token_type_ids" in model_kwargs:
|
||||||
|
token_type_ids = model_kwargs["token_type_ids"]
|
||||||
|
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
|
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
|
||||||
|
|
||||||
@@ -194,6 +198,11 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
model_kwargs["past"] = None
|
model_kwargs["past"] = None
|
||||||
|
|
||||||
|
# update token_type_ids with last value
|
||||||
|
if "token_type_ids" in model_kwargs:
|
||||||
|
token_type_ids = model_kwargs["token_type_ids"]
|
||||||
|
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
# update attention mask
|
# update attention mask
|
||||||
if not is_encoder_decoder:
|
if not is_encoder_decoder:
|
||||||
if "attention_mask" in model_kwargs:
|
if "attention_mask" in model_kwargs:
|
||||||
|
|||||||
@@ -708,9 +708,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
@@ -729,6 +732,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||||
@@ -836,14 +840,32 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||||
|
|||||||
@@ -469,12 +469,26 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||||
|
input_ids = inputs["input_ids"].to(torch_device)
|
||||||
|
token_type_ids = torch.cat(
|
||||||
|
[
|
||||||
|
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
|
||||||
|
input_ids.new_full((input_ids.shape[0], 1), 500),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = model.generate(
|
outputs = model.generate(
|
||||||
input_ids=inputs["input_ids"].to(torch_device),
|
input_ids=input_ids,
|
||||||
attention_mask=inputs["attention_mask"].to(torch_device),
|
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
outputs_tt = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||||
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||||
|
|
||||||
@@ -483,6 +497,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||||
|
|
||||||
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
|
||||||
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||||
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||||
|
|
||||||
@@ -491,6 +506,67 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
"Today, I'm going to be doing a lot of research on this. I",
|
"Today, I'm going to be doing a lot of research on this. I",
|
||||||
]
|
]
|
||||||
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||||
|
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
|
||||||
|
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_2heads(self):
|
||||||
|
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
|
||||||
|
model.to(torch_device)
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
# This tokenizer has no pad token, so we have to set it in some way
|
||||||
|
# Define PAD Token = EOS Token = 50256
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
model.config.pad_token_id = model.config.eos_token_id
|
||||||
|
|
||||||
|
# use different length sentences to test batching
|
||||||
|
sentences = [
|
||||||
|
"Hello, my dog is a little",
|
||||||
|
"Today, I",
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||||
|
input_ids = inputs["input_ids"].to(torch_device)
|
||||||
|
token_type_ids = torch.cat(
|
||||||
|
[
|
||||||
|
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
|
||||||
|
input_ids.new_full((input_ids.shape[0], 1), 500),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_tt = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=inputs["attention_mask"].to(torch_device),
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
output_non_padded = model.generate(input_ids=inputs_non_padded)
|
||||||
|
|
||||||
|
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
|
||||||
|
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
|
||||||
|
|
||||||
|
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
|
||||||
|
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
|
||||||
|
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
expected_output_sentence = [
|
||||||
|
"Hello, my dog is a little bit of a mess. I'm not sure if he's going",
|
||||||
|
"Today, I'm going to be doing a lot of research on this. I",
|
||||||
|
]
|
||||||
|
self.assertListEqual(expected_output_sentence, batch_out_sentence)
|
||||||
|
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
|
||||||
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -540,11 +616,23 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
|
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
|
||||||
|
input_ids = tokenized.input_ids.to(torch_device)
|
||||||
output_ids = model.generate(input_ids, do_sample=True)
|
output_ids = model.generate(input_ids, do_sample=True)
|
||||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
token_type_ids = tokenized.token_type_ids.to(torch_device)
|
||||||
|
output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
|
||||||
|
output_seq_tt = model.generate(
|
||||||
|
input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
|
||||||
|
)
|
||||||
|
output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
|
||||||
|
output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
|
||||||
|
|
||||||
EXPECTED_OUTPUT_STR = (
|
EXPECTED_OUTPUT_STR = (
|
||||||
"Today is a nice day and if you don't know anything about the state of play during your holiday"
|
"Today is a nice day and if you don't know anything about the state of play during your holiday"
|
||||||
)
|
)
|
||||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||||
|
self.assertTrue(
|
||||||
|
all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
|
||||||
|
) # token_type_ids should change output
|
||||||
|
|||||||
Reference in New Issue
Block a user