fixed all tests, still need to check ctrl tf and pt and xlm tf

This commit is contained in:
patrickvonplaten
2020-03-08 21:45:55 +01:00
parent b4a3a64744
commit fbd02d4693
7 changed files with 51 additions and 49 deletions

View File

@@ -403,28 +403,29 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
input_ids = torch.tensor([[1, 14, 2232, 26, 1]], dtype=torch.long, device=torch_device) # The dog is cute
input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device) # the president
expected_output_ids = [
1,
14,
2232,
26,
1,
567,
26,
32,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)
447,
14,
447,
14,
447,
14,
447,
14,
447,
14,
447,
14,
447,
14,
447,
14,
447,
14,
447,
] # the president the president the president the president the president the president the president the president the president the president
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)