updated all tests
This commit is contained in:
@@ -218,7 +218,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_transfo_xl_wt103(self):
|
||||
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
|
||||
input_ids = torch.Tensor(
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
33,
|
||||
@@ -363,8 +363,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
24,
|
||||
0,
|
||||
]
|
||||
]
|
||||
).long()
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||
# ( except for Alexei and Maria ) are discovered .
|
||||
# The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
|
||||
@@ -545,14 +547,23 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
28,
|
||||
1110,
|
||||
3,
|
||||
57,
|
||||
629,
|
||||
38,
|
||||
3493,
|
||||
47,
|
||||
1094,
|
||||
7,
|
||||
1297,
|
||||
13,
|
||||
1041,
|
||||
4,
|
||||
24,
|
||||
603,
|
||||
490,
|
||||
2,
|
||||
71477,
|
||||
20098,
|
||||
104447,
|
||||
2,
|
||||
20961,
|
||||
1,
|
||||
2604,
|
||||
4,
|
||||
1,
|
||||
329,
|
||||
3,
|
||||
0,
|
||||
]
|
||||
@@ -566,10 +577,9 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
# is chased outside and beaten. Twenty years later, Rasputin sees a vision
|
||||
# of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for
|
||||
# his blessing. Rasputin first appears as a priest in 1996, in the same year
|
||||
# that the remains of Russian Tsar Nicholas II and his family were discovered. H
|
||||
# his blessing. <unk> <unk> <eos> In the 1990s, the remains of Russian Tsar
|
||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200)
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user