update tests
This commit is contained in:
@@ -88,13 +88,13 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||||
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
||||||
|
|
||||||
multiple_choice_labels = None
|
mc_labels = None
|
||||||
lm_labels = None
|
lm_labels = None
|
||||||
multiple_choice_token_mask = None
|
mc_token_mask = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
||||||
multiple_choice_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
|
mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
|
||||||
|
|
||||||
config = OpenAIGPTConfig(
|
config = OpenAIGPTConfig(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
@@ -110,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
return (config, input_ids, token_type_ids, position_ids,
|
return (config, input_ids, token_type_ids, position_ids,
|
||||||
multiple_choice_labels, lm_labels, multiple_choice_token_mask)
|
mc_labels, lm_labels, mc_token_mask)
|
||||||
|
|
||||||
def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
|
def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
|
||||||
multiple_choice_labels, lm_labels, multiple_choice_token_mask):
|
mc_labels, lm_labels, mc_token_mask):
|
||||||
model = OpenAIGPTModel(config)
|
model = OpenAIGPTModel(config)
|
||||||
hidden_states = model(input_ids, position_ids, token_type_ids)
|
hidden_states = model(input_ids, position_ids, token_type_ids)
|
||||||
outputs = {
|
outputs = {
|
||||||
@@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
|
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
|
||||||
multiple_choice_labels, lm_labels, multiple_choice_token_mask):
|
mc_labels, lm_labels, mc_token_mask):
|
||||||
model = OpenAIGPTLMHeadModel(config)
|
model = OpenAIGPTLMHeadModel(config)
|
||||||
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
|
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
|
||||||
lm_logits = model(input_ids, position_ids, token_type_ids)
|
lm_logits = model(input_ids, position_ids, token_type_ids)
|
||||||
@@ -150,15 +150,16 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
[])
|
[])
|
||||||
|
|
||||||
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
|
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
|
||||||
multiple_choice_labels, lm_labels, multiple_choice_token_mask):
|
mc_labels, lm_labels, mc_token_mask):
|
||||||
model = OpenAIGPTDoubleHeadsModel(config)
|
model = OpenAIGPTDoubleHeadsModel(config)
|
||||||
loss = model(input_ids, multiple_choice_token_mask, position_ids,
|
loss = model(input_ids, mc_token_mask,
|
||||||
token_type_ids, lm_labels, multiple_choice_labels)
|
lm_labels=lm_labels, mc_labels=mc_labels,
|
||||||
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask, position_ids, token_type_ids)
|
token_type_ids=token_type_ids, position_ids=position_ids)
|
||||||
|
lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||||
outputs = {
|
outputs = {
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"lm_logits": lm_logits,
|
"lm_logits": lm_logits,
|
||||||
"multiple_choice_logits": multiple_choice_logits,
|
"mc_logits": mc_logits,
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -168,7 +169,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
list(result["lm_logits"].size()),
|
list(result["lm_logits"].size()),
|
||||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["multiple_choice_logits"].size()),
|
list(result["mc_logits"].size()),
|
||||||
[self.batch_size, self.n_choices])
|
[self.batch_size, self.n_choices])
|
||||||
|
|
||||||
def check_openai_double_heads_loss_output(self, result):
|
def check_openai_double_heads_loss_output(self, result):
|
||||||
|
|||||||
Reference in New Issue
Block a user