splitting position and token embeddings in OpenAI GPT - updating tf imports - tests

This commit is contained in:
thomwolf
2019-01-29 10:31:42 +01:00
parent 5456d82311
commit 98c96fb1a7
7 changed files with 66 additions and 44 deletions

View File

@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
use_labels=True,
vocab_size=99,
n_special=1,
n_ctx=33,
n_positions=33,
n_embd=32,
n_layer=5,
n_head=4,
@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_special = n_special
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
position_ids = None
if self.use_position_ids:
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx)
position_ids = position_ids + self.n_special + self.vocab_size
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
token_type_ids = None
if self.use_token_type_ids:
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.vocab_size + self.n_special
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
mc_labels = None
@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
n_ctx=self.n_ctx,
n_positions=self.n_positions,
n_special=self.n_special,
n_embd=self.n_embd,
n_layer=self.n_layer,
@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return outputs
def check_openai_lm_head_output(self, result):
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return outputs
def check_openai_double_heads_output(self, result):
total_voc = self.n_ctx + self.n_special + self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])