fixed all tests, still need to check ctrl tf and pt and xlm tf

This commit is contained in:
patrickvonplaten
2020-03-08 21:45:55 +01:00
parent b4a3a64744
commit fbd02d4693
7 changed files with 51 additions and 49 deletions

View File

@@ -573,4 +573,4 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
# TODO: add this test when trasnfo-xl-lmhead is implemented
with self.assertRaises(NotImplementedError):
model.generate(input_ids, max_length=200, do_sample=False)
# self.assertListEqual(output_ids[0].tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented