Fix attn mask gpt2 when using past (#3033)
* fix issue and add some tests * fix issue and add some tests * updated doc string gpt2
This commit is contained in:
committed by
GitHub
parent
9cda3620b6
commit
fdd61b1992
@@ -276,14 +276,17 @@ GPT2_START_DOCSTRING = r"""
|
|||||||
|
|
||||||
GPT2_INPUTS_DOCSTRING = r"""
|
GPT2_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||||
|
`input_ids_length` = `sequence_length if `past` is None else 1
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
If using `past` as an input make sure that `input_ids` are those of the last position.
|
||||||
|
|
||||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
@@ -294,10 +297,12 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
`input_ids_length` = `sequence_length if `past` is None else 1
|
||||||
Segment token indices to indicate first and second portions of the inputs.
|
Segment token indices to indicate first and second portions of the inputs.
|
||||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||||
corresponds to a `sentence B` token
|
corresponds to a `sentence B` token
|
||||||
|
If using `past` as an input make sure that `token_type_ids` correspond to the `input_ids` of the last position.
|
||||||
|
|
||||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -419,7 +424,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.view(-1, input_shape[-1])
|
batch_size = input_ids.shape[0]
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
|||||||
0
tests/test_modeling_common.py!
Normal file
0
tests/test_modeling_common.py!
Normal file
@@ -170,6 +170,72 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
||||||
|
|
||||||
|
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
|
model = GPT2Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
output, past = model(input_ids, token_type_ids=token_type_ids)
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)
|
||||||
|
|
||||||
|
# append to next input_ids and token_type_ids
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past, _ = model(next_input_ids, token_type_ids=next_token_type_ids)
|
||||||
|
output_from_past, _ = model(next_tokens, token_type_ids=next_token_types, past=past)
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def create_and_check_gpt2_model_attention_mask_past(
|
||||||
|
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||||
|
):
|
||||||
|
model = GPT2Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# create attention mask
|
||||||
|
attn_mask = torch.ones(input_ids.shape).long()
|
||||||
|
half_seq_length = self.seq_length // 2
|
||||||
|
attn_mask[:, half_seq_length:] = 0
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
output, past = model(input_ids, attention_mask=attn_mask)
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||||
|
|
||||||
|
# change a random masked slice from input_ids
|
||||||
|
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||||
|
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||||
|
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||||
|
|
||||||
|
# append to next input_ids and attn_mask
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
|
||||||
|
|
||||||
|
# get two different outputs
|
||||||
|
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
|
||||||
|
output_from_past, _ = model(next_tokens, past=past, attention_mask=attn_mask)
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
model = GPT2LMHeadModel(config)
|
model = GPT2LMHeadModel(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -248,6 +314,14 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_gpt2_model(*config_and_inputs)
|
self.model_tester.create_and_check_gpt2_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_gpt2_model_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_gpt2_model_past(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_gpt2_model_att_mask_past(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_gpt2_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
def test_gpt2_lm_head_model(self):
|
def test_gpt2_lm_head_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user