From b12541c4dc70f7fb6cbf4eae79e50d1a9d6a7700 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 9 Mar 2020 13:58:01 +0000 Subject: [PATCH] test ctrl --- tests/test_modeling_ctrl.py | 40 +++++++++++++++++----------------- tests/test_modeling_tf_ctrl.py | 38 ++++++++++++++++---------------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/test_modeling_ctrl.py b/tests/test_modeling_ctrl.py index 506bd2b344..05585f6db3 100644 --- a/tests/test_modeling_ctrl.py +++ b/tests/test_modeling_ctrl.py @@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase): def test_lm_generate_ctrl(self): model = CTRLLMHeadModel.from_pretrained("ctrl") input_ids = torch.tensor( - [[11858, 586, 20984, 8]], dtype=torch.long, device=torch_device - ) # Legal My neighbor is + [[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device + ) # Legal the president is expected_output_ids = [ 11859, - 586, - 20984, + 0, + 1611, 8, - 13391, - 3, - 980, - 8258, - 72, - 327, - 148, + 5, + 150, + 26449, 2, - 53, - 29, - 226, + 19, + 348, + 469, 3, - 780, - 49, - 3, - 980, - ] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay + 2595, + 48, + 20740, + 246533, + 246533, + 19, + 30, + 5, + ] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index d3856a8abd..a04ca7f466 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_ctrl(self): model = TFCTRLLMHeadModel.from_pretrained("ctrl") - input_ids = tf.convert_to_tensor([[11858, 586, 20984, 8]], dtype=tf.int32) + input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is expected_output_ids = [ 11859, - 586, - 20984, + 0, + 1611, 8, - 13391, - 3, - 980, - 8258, - 72, - 327, - 148, + 5, + 150, + 26449, 2, - 53, - 29, - 226, + 19, + 348, + 469, 3, - 780, - 49, - 3, - 980, - ] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay + 2595, + 48, + 20740, + 246533, + 246533, + 19, + 30, + 5, + ] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)