From afb50c663a5d5623906ead1e87481926467d59fa Mon Sep 17 00:00:00 2001 From: LSinev Date: Mon, 16 Nov 2020 16:35:44 +0300 Subject: [PATCH] Fix GPT2DoubleHeadsModel to work with model.generate() (#6601) * Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used and for GPT2LMHeadModel too * Update tests to check token_type_ids usage in GPT2 models --- src/transformers/generation_utils.py | 9 +++ src/transformers/modeling_gpt2.py | 22 +++++++ tests/test_modeling_gpt2.py | 92 +++++++++++++++++++++++++++- 3 files changed, 121 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 206658da98..2e3aaa979d 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -144,6 +144,10 @@ class GenerationMixin: ) input_ids = input_ids.index_select(0, expanded_return_idx) + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) + if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) @@ -194,6 +198,11 @@ class GenerationMixin: else: model_kwargs["past"] = None + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + # update attention mask if not is_encoder_decoder: if "attention_mask" in model_kwargs: diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 442f78ec43..45b4bebfd6 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -708,9 +708,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past: input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -729,6 +732,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, + "token_type_ids": token_type_ids, } @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @@ -836,14 +840,32 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): return self.lm_head def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past: input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None return { "input_ids": input_ids, "past_key_values": past, "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, } @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index aa6133d35c..6b8fbbbc9f 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -469,12 +469,26 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ] inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + token_type_ids = torch.cat( + [ + input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), + input_ids.new_full((input_ids.shape[0], 1), 500), + ], + dim=-1, + ) outputs = model.generate( - input_ids=inputs["input_ids"].to(torch_device), + input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), ) + outputs_tt = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + token_type_ids=token_type_ids, + ) + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) output_non_padded = model.generate(input_ids=inputs_non_padded) @@ -483,6 +497,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) @@ -491,6 +506,67 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): "Today, I'm going to be doing a lot of research on this. I", ] self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output + self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + + @slow + def test_batch_generation_2heads(self): + model = GPT2DoubleHeadsModel.from_pretrained("gpt2") + model.to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + tokenizer.padding_side = "left" + + # This tokenizer has no pad token, so we have to set it in some way + # Define PAD Token = EOS Token = 50256 + tokenizer.pad_token = tokenizer.eos_token + model.config.pad_token_id = model.config.eos_token_id + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + token_type_ids = torch.cat( + [ + input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), + input_ids.new_full((input_ids.shape[0], 1), 500), + ], + dim=-1, + ) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + outputs_tt = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + token_type_ids=token_type_ids, + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little bit of a mess. I'm not sure if he's going", + "Today, I'm going to be doing a lot of research on this. I", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) @slow @@ -540,11 +616,23 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): model.to(torch_device) torch.manual_seed(0) - input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) output_ids = model.generate(input_ids, do_sample=True) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + token_type_ids = tokenized.token_type_ids.to(torch_device) + output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) + output_seq_tt = model.generate( + input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 + ) + output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) + output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True) + EXPECTED_OUTPUT_STR = ( "Today is a nice day and if you don't know anything about the state of play during your holiday" ) self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + self.assertTrue( + all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) + ) # token_type_ids should change output