test ctrl
This commit is contained in:
@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||
def test_lm_generate_ctrl(self):
|
||||
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||
input_ids = torch.tensor(
|
||||
[[11858, 586, 20984, 8]], dtype=torch.long, device=torch_device
|
||||
) # Legal My neighbor is
|
||||
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
||||
) # Legal the president is
|
||||
expected_output_ids = [
|
||||
11859,
|
||||
586,
|
||||
20984,
|
||||
0,
|
||||
1611,
|
||||
8,
|
||||
13391,
|
||||
3,
|
||||
980,
|
||||
8258,
|
||||
72,
|
||||
327,
|
||||
148,
|
||||
5,
|
||||
150,
|
||||
26449,
|
||||
2,
|
||||
53,
|
||||
29,
|
||||
226,
|
||||
19,
|
||||
348,
|
||||
469,
|
||||
3,
|
||||
780,
|
||||
49,
|
||||
3,
|
||||
980,
|
||||
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
||||
2595,
|
||||
48,
|
||||
20740,
|
||||
246533,
|
||||
246533,
|
||||
19,
|
||||
30,
|
||||
5,
|
||||
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_ctrl(self):
|
||||
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
|
||||
input_ids = tf.convert_to_tensor([[11858, 586, 20984, 8]], dtype=tf.int32)
|
||||
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
|
||||
expected_output_ids = [
|
||||
11859,
|
||||
586,
|
||||
20984,
|
||||
0,
|
||||
1611,
|
||||
8,
|
||||
13391,
|
||||
3,
|
||||
980,
|
||||
8258,
|
||||
72,
|
||||
327,
|
||||
148,
|
||||
5,
|
||||
150,
|
||||
26449,
|
||||
2,
|
||||
53,
|
||||
29,
|
||||
226,
|
||||
19,
|
||||
348,
|
||||
469,
|
||||
3,
|
||||
780,
|
||||
49,
|
||||
3,
|
||||
980,
|
||||
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
||||
2595,
|
||||
48,
|
||||
20740,
|
||||
246533,
|
||||
246533,
|
||||
19,
|
||||
30,
|
||||
5,
|
||||
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user