test ctrl
This commit is contained in:
@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
def test_lm_generate_ctrl(self):
|
def test_lm_generate_ctrl(self):
|
||||||
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[[11858, 586, 20984, 8]], dtype=torch.long, device=torch_device
|
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
||||||
) # Legal My neighbor is
|
) # Legal the president is
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
11859,
|
11859,
|
||||||
586,
|
0,
|
||||||
20984,
|
1611,
|
||||||
8,
|
8,
|
||||||
13391,
|
5,
|
||||||
3,
|
150,
|
||||||
980,
|
26449,
|
||||||
8258,
|
|
||||||
72,
|
|
||||||
327,
|
|
||||||
148,
|
|
||||||
2,
|
2,
|
||||||
53,
|
19,
|
||||||
29,
|
348,
|
||||||
226,
|
469,
|
||||||
3,
|
3,
|
||||||
780,
|
2595,
|
||||||
49,
|
48,
|
||||||
3,
|
20740,
|
||||||
980,
|
246533,
|
||||||
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
246533,
|
||||||
|
19,
|
||||||
|
30,
|
||||||
|
5,
|
||||||
|
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, do_sample=False)
|
output_ids = model.generate(input_ids, do_sample=False)
|
||||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_lm_generate_ctrl(self):
|
def test_lm_generate_ctrl(self):
|
||||||
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
|
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
|
||||||
input_ids = tf.convert_to_tensor([[11858, 586, 20984, 8]], dtype=tf.int32)
|
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
|
||||||
expected_output_ids = [
|
expected_output_ids = [
|
||||||
11859,
|
11859,
|
||||||
586,
|
0,
|
||||||
20984,
|
1611,
|
||||||
8,
|
8,
|
||||||
13391,
|
5,
|
||||||
3,
|
150,
|
||||||
980,
|
26449,
|
||||||
8258,
|
|
||||||
72,
|
|
||||||
327,
|
|
||||||
148,
|
|
||||||
2,
|
2,
|
||||||
53,
|
19,
|
||||||
29,
|
348,
|
||||||
226,
|
469,
|
||||||
3,
|
3,
|
||||||
780,
|
2595,
|
||||||
49,
|
48,
|
||||||
3,
|
20740,
|
||||||
980,
|
246533,
|
||||||
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
246533,
|
||||||
|
19,
|
||||||
|
30,
|
||||||
|
5,
|
||||||
|
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, do_sample=False)
|
output_ids = model.generate(input_ids, do_sample=False)
|
||||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user