fix tests
This commit is contained in:
@@ -93,7 +93,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
||||||
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()
|
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)
|
||||||
|
|
||||||
config = OpenAIGPTConfig(
|
config = OpenAIGPTConfig(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user