Fix the GPT2 double-heads model and update the torch version tests
This commit is contained in:
@@ -367,6 +367,13 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
self.h[layer].attn.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
@@ -378,6 +385,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(-1, input_shape[-1])
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
@@ -407,14 +415,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.n_layer
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
|
||||
@@ -314,17 +314,16 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
|
||||
def _linear(self, inputs):
|
||||
"""Computes logits by running inputs through a linear layer.
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [batch_size, length, hidden_size]
|
||||
inputs: A float32 tensor with shape [..., hidden_size]
|
||||
Returns:
|
||||
float32 tensor with shape [batch_size, length, vocab_size].
|
||||
float32 tensor with shape [..., vocab_size].
|
||||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
first_dims = shape_list(inputs)[:-1]
|
||||
|
||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||
logits = tf.matmul(x, self.weight, transpose_b=True)
|
||||
|
||||
return tf.reshape(logits, [batch_size, length, self.vocab_size])
|
||||
return tf.reshape(logits, first_dims + [self.vocab_size])
|
||||
|
||||
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
@@ -679,10 +678,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
@tf.function
|
||||
def call(self, inputs, training=False):
|
||||
if not isinstance(inputs, (dict, tuple, list)):
|
||||
raise ValueError("Inputs should be a list or a dict with at least two elements: 'inputs_ids' and 'mc_token_ids'")
|
||||
input_ids = inputs
|
||||
mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
mc_token_ids = inputs[1]
|
||||
mc_token_ids = inputs[1] if len(inputs) > 1 else None
|
||||
past = inputs[2] if len(inputs) > 2 else None
|
||||
attention_mask = inputs[3] if len(inputs) > 3 else None
|
||||
token_type_ids = inputs[4] if len(inputs) > 4 else None
|
||||
@@ -691,7 +691,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs.get('input_ids')
|
||||
mc_token_ids = inputs.get('mc_token_ids')
|
||||
mc_token_ids = inputs.get('mc_token_ids', None)
|
||||
past = inputs.get('past', None)
|
||||
attention_mask = inputs.get('attention_mask', None)
|
||||
token_type_ids = inputs.get('token_type_ids', None)
|
||||
@@ -699,9 +699,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
head_mask = inputs.get('head_mask', None)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
|
||||
assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length"
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
input_shapes = shape_list(input_ids)
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
@@ -710,13 +710,16 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
|
||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
|
||||
outputs = self.transformer(flat_inputs, training=training)
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
|
||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
||||
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
||||
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
||||
return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions)
|
||||
|
||||
@@ -359,13 +359,18 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
||||
elif self.summary_type == 'mean':
|
||||
output = tf.mean(hidden_states, axis=1)
|
||||
elif self.summary_type == 'cls_index':
|
||||
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
|
||||
if cls_index is None:
|
||||
cls_index = tf.fill(tf.shape(hidden_states[..., :1, :]), hidden_states.shape[-2]-1, dtype=tf.int32)
|
||||
else:
|
||||
cls_index = cls_index[..., tf.newaxis, tf.newaxis]
|
||||
cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
|
||||
cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1) # A tensor of shape [batch] or [batch, num choices] filled with the last sequence index (sequence length - 1)
|
||||
cls_shape = shape_list(cls_index)
|
||||
if len(cls_shape) <= len(hidden_shape) - 2:
|
||||
cls_index = cls_index[..., tf.newaxis]
|
||||
# else:
|
||||
# cls_index = cls_index[..., tf.newaxis]
|
||||
# cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
|
||||
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
|
||||
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
|
||||
output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
|
||||
output = tf.squeeze(output, axis=len(hidden_shape) - 2) # shape of output: (batch, num choices, hidden_size)
|
||||
elif self.summary_type == 'attn':
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -679,7 +679,7 @@ class SequenceSummary(nn.Module):
|
||||
self.last_dropout = nn.Dropout(config.summary_last_dropout)
|
||||
|
||||
def forward(self, hidden_states, cls_index=None):
|
||||
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
|
||||
""" hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
|
||||
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
|
||||
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
|
||||
if summary_type == 'cls_index' and cls_index is None:
|
||||
|
||||
@@ -46,6 +46,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
use_token_type_ids=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
use_mc_token_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
@@ -69,6 +70,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.use_mc_token_ids = use_mc_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -96,6 +98,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
mc_token_ids = None
|
||||
if self.use_mc_token_ids:
|
||||
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
@@ -121,7 +127,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
|
||||
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
@@ -163,15 +169,27 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
|
||||
model = GPT2DoubleHeadsModel(config)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits, mc_logits, _ = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
|
||||
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
|
||||
inputs = {'input_ids': multiple_choice_inputs_ids,
|
||||
'mc_token_ids': mc_token_ids,
|
||||
'attention_mask': multiple_choice_input_mask,
|
||||
'token_type_ids': multiple_choice_token_type_ids,
|
||||
'lm_labels': multiple_choice_inputs_ids}
|
||||
|
||||
loss, lm_logits, mc_logits, _ = model(**inputs)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"lm_logits": lm_logits
|
||||
"lm_logits": lm_logits,
|
||||
"mc_logits": mc_logits
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
@@ -179,11 +197,17 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
[self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(result["mc_logits"].size()),
|
||||
[self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
|
||||
(config, input_ids, input_mask, head_mask, token_type_ids,
|
||||
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
'input_ids': input_ids,
|
||||
'token_type_ids': token_type_ids,
|
||||
|
||||
@@ -37,9 +37,9 @@ else:
|
||||
|
||||
class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
|
||||
# TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
|
||||
TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
||||
|
||||
class TFGPT2ModelTester(object):
|
||||
|
||||
@@ -51,6 +51,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
use_token_type_ids=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
use_mc_token_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
@@ -74,6 +75,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.use_mc_token_ids = use_mc_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@@ -101,6 +103,10 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
mc_token_ids = None
|
||||
if self.use_mc_token_ids:
|
||||
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
@@ -126,7 +132,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
|
||||
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2Model(config=config)
|
||||
@@ -162,25 +168,34 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
||||
def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
pass
|
||||
# model = TFGPT2DoubleHeadsModel(config=config)
|
||||
# inputs = {'input_ids': input_ids,
|
||||
# 'attention_mask': input_mask,
|
||||
# 'token_type_ids': token_type_ids}
|
||||
# seq_relationship_score, = model(inputs)[0]
|
||||
# result = {
|
||||
# "seq_relationship_score": seq_relationship_score.numpy(),
|
||||
# }
|
||||
# self.parent.assertListEqual(
|
||||
# list(result["seq_relationship_score"].shape),
|
||||
# [self.batch_size, 2])
|
||||
def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
|
||||
model = TFGPT2DoubleHeadsModel(config=config)
|
||||
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
|
||||
inputs = {'input_ids': multiple_choice_inputs_ids,
|
||||
'mc_token_ids': mc_token_ids,
|
||||
'attention_mask': multiple_choice_input_mask,
|
||||
'token_type_ids': multiple_choice_token_type_ids}
|
||||
lm_logits, mc_logits = model(inputs)[:2]
|
||||
result = {
|
||||
"lm_logits": lm_logits.numpy(),
|
||||
"mc_logits": mc_logits.numpy()
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].shape),
|
||||
[self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(result["mc_logits"].shape),
|
||||
[self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
(config, input_ids, input_mask, head_mask, token_type_ids,
|
||||
sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
Reference in New Issue
Block a user