From 6726416e4a9780e7a92b5681e1446f15f7ef83d3 Mon Sep 17 00:00:00 2001 From: Teven Date: Thu, 2 Jul 2020 11:56:44 +0200 Subject: [PATCH] Changed expected_output_ids in TransfoXL generation test (#5462) * Changed expected_output_ids in TransfoXL generation test to match #4826 generation PR. * making black happy * making isort happy --- tests/test_modeling_transfo_xl.py | 103 ++++++++++++++++++------------ 1 file changed, 62 insertions(+), 41 deletions(-) diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index db552b2798..73a8036f24 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -162,7 +162,6 @@ class TransfoXLModelTester: @require_torch class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () test_pruning = False @@ -448,7 +447,6 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): # father initially slaps him for making such an accusation , Rasputin watches as the # man is chased outside and beaten . Twenty years later , Rasputin sees a vision of # the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous , - # with people , even a bishop , begging for his blessing . expected_output_ids = [ @@ -595,54 +593,77 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): 0, 33, 1, - 1857, + 142, + 1298, + 188, 2, - 1, - 1009, + 29546, + 113, + 8, + 3654, 4, + 1, 1109, - 11739, - 4762, - 358, - 5, - 25, - 245, - 28, - 1110, + 7136, + 833, 3, 13, - 1041, + 1645, 4, + 29546, + 11, + 104, + 7, + 1, + 1109, + 532, + 7129, + 2, + 10, + 83507, + 2, + 1162, + 1123, + 2, + 6, + 7245, + 10, + 2, + 5, + 11, + 104, + 7, + 1, + 1109, + 532, + 7129, + 2, + 10, 24, - 603, - 490, - 2, - 71477, - 20098, - 104447, - 2, - 20961, - 1, - 2604, + 24, + 10, + 22, + 10, + 13, + 770, + 5863, 4, - 1, - 329, - 3, - 0, + 7245, + 10, ] - # In 1991, the remains of Russian Tsar Nicholas II and his family ( - # except for Alexei and Maria ) are discovered. The voice of young son, - # Tsarevich Alexei Nikolaevich, narrates the remainder of the story. - # 1883 Western Siberia, a young Grigori Rasputin is asked by his father - # and a group of men to perform magic. Rasputin has a vision and - # denounces one of the men as a horse thief. Although his father initially - # slaps him for making such an accusation, Rasputin watches as the man - # is chased outside and beaten. Twenty years later, Rasputin sees a vision - # of the Virgin Mary, prompting him to become a priest. - # Rasputin quickly becomes famous, with people, even a bishop, begging for - # his blessing. In the 1990s, the remains of Russian Tsar - # Nicholas II and his family were discovered. The voice of young son, - # Tsarevich Alexei Nikolaevich, narrates the remainder of the story. + # In 1991, the remains of Russian Tsar Nicholas II and his family ( except for + # Alexei and Maria ) are discovered. The voice of young son, Tsarevich Alexei + # Nikolaevich, narrates the remainder of the story. 1883 Western Siberia, a young + # Grigori Rasputin is asked by his father and a group of men to perform magic. + # Rasputin has a vision and denounces one of the men as a horse thief. Although + # his father initially slaps him for making such an accusation, Rasputin watches + # as the man is chased outside and beaten. Twenty years later, Rasputin sees a + # vision of the Virgin Mary, prompting him to become a priest. Rasputin quickly + # becomes famous, with people, even a bishop, begging for his blessing. In the + # early 20th century, Rasputin became a symbol of the Russian Orthodox Church. + # The image of Rasputin was used in the Russian national anthem, " Nearer, My God, + # to Heaven ", and was used in the Russian national anthem, " " ( " The Great Spirit + # of Heaven " output_ids = model.generate(input_ids, max_length=200, do_sample=False) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)