updated all tests
This commit is contained in:
@@ -123,7 +123,15 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
@@ -139,7 +147,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
result = {"sequence_output": sequence_output}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
@@ -153,7 +161,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
@@ -167,7 +175,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -181,7 +189,11 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask}
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -215,30 +227,29 @@ class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_openai_gpt(self):
|
||||
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
|
||||
input_ids = torch.Tensor([[481, 2585, 544, 4957]]).long() # The dog is cute
|
||||
input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is
|
||||
expected_output_ids = [
|
||||
481,
|
||||
2585,
|
||||
4735,
|
||||
544,
|
||||
4957,
|
||||
669,
|
||||
512,
|
||||
761,
|
||||
5990,
|
||||
271,
|
||||
645,
|
||||
246,
|
||||
963,
|
||||
870,
|
||||
762,
|
||||
239,
|
||||
244,
|
||||
40477,
|
||||
244,
|
||||
249,
|
||||
719,
|
||||
881,
|
||||
487,
|
||||
535,
|
||||
976,
|
||||
2479,
|
||||
544,
|
||||
240,
|
||||
487,
|
||||
804,
|
||||
1296,
|
||||
2891,
|
||||
512,
|
||||
] # the dog is cute when you're annoyed : if he's really stupid, he 'll stop fighting you
|
||||
torch.manual_seed(0)
|
||||
244,
|
||||
603,
|
||||
481,
|
||||
] # the president is a very good man. " \n " i\'m sure he is, " said the
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user