mc_token_mask => mc_token_ids

This commit is contained in:
thomwolf
2019-02-09 16:58:53 +01:00
parent f4a07a392c
commit 1320e4ec0c
3 changed files with 35 additions and 40 deletions

View File

@@ -89,11 +89,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
mc_labels = None
lm_labels = None
mc_token_mask = None
mc_token_ids = None
if self.use_labels:
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
@@ -109,10 +109,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_mask)
mc_labels, lm_labels, mc_token_ids)
def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_mask):
mc_labels, lm_labels, mc_token_ids):
model = OpenAIGPTModel(config)
model.eval()
hidden_states = model(input_ids, position_ids, token_type_ids)
@@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_mask):
mc_labels, lm_labels, mc_token_ids):
model = OpenAIGPTLMHeadModel(config)
model.eval()
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
@@ -151,13 +151,13 @@ class OpenAIGPTModelTest(unittest.TestCase):
[])
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_mask):
mc_labels, lm_labels, mc_token_ids):
model = OpenAIGPTDoubleHeadsModel(config)
model.eval()
loss = model(input_ids, mc_token_mask,
loss = model(input_ids, mc_token_ids,
lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids)
lm_logits, mc_logits = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,