fix if use lang embeddings in tf xlm

This commit is contained in:
Patrick von Platen
2020-03-09 11:18:54 +01:00
parent fbd02d4693
commit 4620caa864
2 changed files with 2 additions and 2 deletions

View File

@@ -342,4 +342,4 @@ class TFXLMModelLanguageGenerationTest(unittest.TestCase):
] # the president the president the president the president the president the president the president the president the president the president
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids, do_sample=False)