updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -517,7 +517,7 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlnet_base_cased(self):
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
input_ids = torch.Tensor(
input_ids = torch.tensor(
[
[
67,
@@ -682,8 +682,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
4,
3,
]
]
).long()
],
dtype=torch.long,
device=torch_device,
)
# In 1991, the remains of Russian Tsar Nicholas II and his family
# (except for Alexei and Maria) are discovered.
# The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
@@ -876,26 +878,36 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
22,
2771,
4901,
25,
18,
2059,
20,
24,
303,
1775,
691,
9,
1147,
69,
27,
50,
551,
22,
2771,
4901,
19,
634,
19,
43,
51,
54,
6157,
2999,
33,
4185,
21,
45,
668,
21,
18,
416,
41,
1499,
22,
755,
18,
14285,
9,
12943,
4354,
153,
27,
1499,
22,
642,
22,
]
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
@@ -905,11 +917,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
# 1990, a priest who cannot even walk with his wife, Maria, is asked to perform magic
# in the presence of a local religious leader.
# Since, however, he has had difficulty walking with Maria
# <sep><cls>, Rasputin is asked to perform magic.
# He is not able to perform magic, and his father and
# the men are forced to leave the monastery. Rasputin is forced to return to
torch.manual_seed(0)
output_ids = model.generate(input_ids, max_length=200)
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)