From 6b3438df21782fe5736b267c05415f64a26a65d9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 9 Sep 2019 12:48:36 +0200 Subject: [PATCH] fixing GPT2 double head model and updating the torch version tests --- pytorch_transformers/modeling_gpt2.py | 13 +++-- pytorch_transformers/modeling_tf_gpt2.py | 29 ++++++----- pytorch_transformers/modeling_tf_utils.py | 15 ++++-- pytorch_transformers/modeling_utils.py | 2 +- .../tests/modeling_gpt2_test.py | 36 ++++++++++--- .../tests/modeling_tf_gpt2_test.py | 51 ++++++++++++------- 6 files changed, 98 insertions(+), 48 deletions(-) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 4268641187..324d020fc3 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -367,6 +367,13 @@ class GPT2Model(GPT2PreTrainedModel): self.h[layer].attn.prune_heads(heads) def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + if past is None: past_length = 0 past = [None] * len(self.h) @@ -378,6 +385,7 @@ class GPT2Model(GPT2PreTrainedModel): # Attention mask. if attention_mask is not None: + attention_mask = attention_mask.view(-1, input_shape[-1]) # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] @@ -407,14 +415,9 @@ class GPT2Model(GPT2PreTrainedModel): else: head_mask = [None] * self.config.n_layer - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_ids.size(-1)) - position_ids = position_ids.view(-1, position_ids.size(-1)) - inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) token_type_embeds = self.wte(token_type_ids) else: token_type_embeds = 0 diff --git a/pytorch_transformers/modeling_tf_gpt2.py b/pytorch_transformers/modeling_tf_gpt2.py index bcb9f5309a..85873c9d1b 100644 --- a/pytorch_transformers/modeling_tf_gpt2.py +++ b/pytorch_transformers/modeling_tf_gpt2.py @@ -314,17 +314,16 @@ class TFGPT2Embeddings(tf.keras.layers.Layer): def _linear(self, inputs): """Computes logits by running inputs through a linear layer. Args: - inputs: A float32 tensor with shape [batch_size, length, hidden_size] + inputs: A float32 tensor with shape [..., hidden_size] Returns: - float32 tensor with shape [batch_size, length, vocab_size]. + float32 tensor with shape [..., vocab_size]. """ - batch_size = shape_list(inputs)[0] - length = shape_list(inputs)[1] + first_dims = shape_list(inputs)[:-1] x = tf.reshape(inputs, [-1, self.hidden_size]) logits = tf.matmul(x, self.weight, transpose_b=True) - return tf.reshape(logits, [batch_size, length, self.vocab_size]) + return tf.reshape(logits, first_dims + [self.vocab_size]) class TFGPT2MainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): @@ -679,10 +678,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): @tf.function def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): - raise ValueError("Inputs should be a list or a dict with at least two elements: 'inputs_ids' and 'mc_token_ids'") + input_ids = inputs + mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None elif isinstance(inputs, (tuple, list)): input_ids = inputs[0] - mc_token_ids = inputs[1] + mc_token_ids = inputs[1] if len(inputs) > 1 else None past = inputs[2] if len(inputs) > 2 else None attention_mask = inputs[3] if len(inputs) > 3 else None token_type_ids = inputs[4] if len(inputs) > 4 else None @@ -691,7 +691,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): assert len(inputs) <= 7, "Too many inputs." else: input_ids = inputs.get('input_ids') - mc_token_ids = inputs.get('mc_token_ids') + mc_token_ids = inputs.get('mc_token_ids', None) past = inputs.get('past', None) attention_mask = inputs.get('attention_mask', None) token_type_ids = inputs.get('token_type_ids', None) @@ -699,9 +699,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): head_mask = inputs.get('head_mask', None) assert len(inputs) <= 5, "Too many inputs." - assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length" - num_choices = shape_list(input_ids)[1] - seq_length = shape_list(input_ids)[2] + input_shapes = shape_list(input_ids) + + seq_length = input_shapes[-1] flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None @@ -710,13 +710,16 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask] - outputs = self.transformer(flat_inputs, training=training) - + transformer_outputs = self.transformer(flat_inputs, training=training) hidden_states = transformer_outputs[0] + hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) + lm_logits = self.transformer.wte(hidden_states, mode="linear") mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training) + mc_logits = tf.squeeze(mc_logits, axis=-1) + outputs = (lm_logits, mc_logits) + transformer_outputs[1:] return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions) diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index af67a8442e..9630e27478 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -359,13 +359,18 @@ class TFSequenceSummary(tf.keras.layers.Layer): elif self.summary_type == 'mean': output = tf.mean(hidden_states, axis=1) elif self.summary_type == 'cls_index': + hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] if cls_index is None: - cls_index = tf.fill(tf.shape(hidden_states[..., :1, :]), hidden_states.shape[-2]-1, dtype=tf.int32) - else: - cls_index = cls_index[..., tf.newaxis, tf.newaxis] - cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) + cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1) # A tensor full of shape [batch] or [batch, num choices] full of sequence length + cls_shape = shape_list(cls_index) + if len(cls_shape) <= len(hidden_shape) - 2: + cls_index = cls_index[..., tf.newaxis] + # else: + # cls_index = cls_index[..., tf.newaxis] + # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),)) # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states - output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2) + output = tf.squeeze(output, axis=len(hidden_shape) - 2) # shape of output: (batch, num choices, hidden_size) elif self.summary_type == 'attn': raise NotImplementedError diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 2fb4671674..9fd7a2c0c2 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -679,7 +679,7 @@ class SequenceSummary(nn.Module): self.last_dropout = nn.Dropout(config.summary_last_dropout) def forward(self, hidden_states, cls_index=None): - """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. + """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer. cls_index: [optional] position of the classification token if summary_type == 'cls_index', shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states. if summary_type == 'cls_index' and cls_index is None: diff --git a/pytorch_transformers/tests/modeling_gpt2_test.py b/pytorch_transformers/tests/modeling_gpt2_test.py index dc7c0d1816..e5accfa8cf 100644 --- a/pytorch_transformers/tests/modeling_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_gpt2_test.py @@ -46,6 +46,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): use_token_type_ids=True, use_input_mask=True, use_labels=True, + use_mc_token_ids=True, vocab_size=99, hidden_size=32, num_hidden_layers=5, @@ -69,6 +70,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): self.use_token_type_ids = use_token_type_ids self.use_input_mask = use_input_mask self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -96,6 +98,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + sequence_labels = None token_labels = None choice_labels = None @@ -121,7 +127,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) - return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels + return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels def check_loss_output(self, result): self.parent.assertListEqual( @@ -163,15 +169,27 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]) - def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args): model = GPT2DoubleHeadsModel(config) model.eval() - loss, lm_logits, mc_logits, _ = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids) + + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + + inputs = {'input_ids': multiple_choice_inputs_ids, + 'mc_token_ids': mc_token_ids, + 'attention_mask': multiple_choice_input_mask, + 'token_type_ids': multiple_choice_token_type_ids, + 'lm_labels': multiple_choice_inputs_ids} + + loss, lm_logits, mc_logits, _ = model(**inputs) result = { "loss": loss, - "lm_logits": lm_logits + "lm_logits": lm_logits, + "mc_logits": mc_logits } self.parent.assertListEqual( @@ -179,11 +197,17 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): []) self.parent.assertListEqual( list(result["lm_logits"].size()), - [self.batch_size, self.seq_length, self.vocab_size]) + [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]) + self.parent.assertListEqual( + list(result["mc_logits"].size()), + [self.batch_size, self.num_choices]) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs + + (config, input_ids, input_mask, head_mask, token_type_ids, + mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs + inputs_dict = { 'input_ids': input_ids, 'token_type_ids': token_type_ids, diff --git a/pytorch_transformers/tests/modeling_tf_gpt2_test.py b/pytorch_transformers/tests/modeling_tf_gpt2_test.py index 5fef1f6453..490d5c4e32 100644 --- a/pytorch_transformers/tests/modeling_tf_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_tf_gpt2_test.py @@ -37,9 +37,9 @@ else: class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): - # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, - # TFGPT2DoubleHeadsModel) if is_tf_available() else () - all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else () + all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, + TFGPT2DoubleHeadsModel) if is_tf_available() else () + # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else () class TFGPT2ModelTester(object): @@ -51,6 +51,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): use_token_type_ids=True, use_input_mask=True, use_labels=True, + use_mc_token_ids=True, vocab_size=99, hidden_size=32, num_hidden_layers=5, @@ -74,6 +75,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): self.use_token_type_ids = use_token_type_ids self.use_input_mask = use_input_mask self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -101,6 +103,10 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + sequence_labels = None token_labels = None choice_labels = None @@ -126,7 +132,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) - return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels + return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = TFGPT2Model(config=config) @@ -162,25 +168,34 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): [self.batch_size, self.seq_length, self.vocab_size]) - def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): - pass - # model = TFGPT2DoubleHeadsModel(config=config) - # inputs = {'input_ids': input_ids, - # 'attention_mask': input_mask, - # 'token_type_ids': token_type_ids} - # seq_relationship_score, = model(inputs)[0] - # result = { - # "seq_relationship_score": seq_relationship_score.numpy(), - # } - # self.parent.assertListEqual( - # list(result["seq_relationship_score"].shape), - # [self.batch_size, 2]) + def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args): + model = TFGPT2DoubleHeadsModel(config=config) + + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1)) + + inputs = {'input_ids': multiple_choice_inputs_ids, + 'mc_token_ids': mc_token_ids, + 'attention_mask': multiple_choice_input_mask, + 'token_type_ids': multiple_choice_token_type_ids} + lm_logits, mc_logits = model(inputs)[:2] + result = { + "lm_logits": lm_logits.numpy(), + "mc_logits": mc_logits.numpy() + } + self.parent.assertListEqual( + list(result["lm_logits"].shape), + [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]) + self.parent.assertListEqual( + list(result["mc_logits"].shape), + [self.batch_size, self.num_choices]) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, head_mask, token_type_ids, - sequence_labels, token_labels, choice_labels) = config_and_inputs + mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} return config, inputs_dict