add special tokens to gpt-2
This commit is contained in:
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
n_special=1,
|
||||
n_positions=33,
|
||||
n_embd=32,
|
||||
n_layer=5,
|
||||
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_special = n_special
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
|
||||
total_num_tokens = self.vocab_size + self.n_special
|
||||
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
|
||||
|
||||
position_ids = None
|
||||
if self.use_position_ids:
|
||||
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
|
||||
config = GPT2Config(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
n_special=self.n_special,
|
||||
n_positions=self.n_positions,
|
||||
n_embd=self.n_embd,
|
||||
n_layer=self.n_layer,
|
||||
@@ -130,7 +134,7 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_gpt2_lm_head_output(self, result):
|
||||
total_voc = self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
@@ -157,7 +161,7 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_gpt2_double_heads_output(self, result):
|
||||
total_voc = self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
|
||||
Reference in New Issue
Block a user