updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -403,7 +403,7 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
input_ids = torch.Tensor([[1, 14, 2232, 26, 1]]).long() # The dog is cute
input_ids = torch.tensor([[1, 14, 2232, 26, 1]], dtype=torch.long, device=torch_device) # The dog is cute
expected_output_ids = [
1,
14,
@@ -426,8 +426,5 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
149,
149,
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
torch.manual_seed(0)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)