splitting position and token embeddings in OpenAI GPT - updating tf imports - tests
This commit is contained in:
@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
n_special=1,
|
||||
n_ctx=33,
|
||||
n_positions=33,
|
||||
n_embd=32,
|
||||
n_layer=5,
|
||||
n_head=4,
|
||||
@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_special = n_special
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
position_ids = None
|
||||
if self.use_position_ids:
|
||||
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx)
|
||||
position_ids = position_ids + self.n_special + self.vocab_size
|
||||
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.vocab_size + self.n_special
|
||||
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
||||
|
||||
mc_labels = None
|
||||
@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
config = OpenAIGPTConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
n_ctx=self.n_ctx,
|
||||
n_positions=self.n_positions,
|
||||
n_special=self.n_special,
|
||||
n_embd=self.n_embd,
|
||||
n_layer=self.n_layer,
|
||||
@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_openai_lm_head_output(self, result):
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_openai_double_heads_output(self, result):
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
|
||||
Reference in New Issue
Block a user