fsmt slow test uses lists (#8031)
This commit is contained in:
@@ -144,11 +144,11 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# for src_text, _ in targets: print(f"""[\n"{src_text}",\n {model.encode(src_text).tolist()}\n],""")
|
# for src_text, _ in targets: print(f"""[\n"{src_text}",\n {model.encode(src_text).tolist()}\n],""")
|
||||||
|
|
||||||
for src_text, tgt_input_ids in targets:
|
for src_text, tgt_input_ids in targets:
|
||||||
input_ids = tokenizer_enc.encode(src_text, return_tensors="pt")[0].tolist()
|
encoded_ids = tokenizer_enc.encode(src_text, return_tensors=None)
|
||||||
self.assertListEqual(input_ids, tgt_input_ids)
|
self.assertListEqual(encoded_ids, tgt_input_ids)
|
||||||
|
|
||||||
# and decode backward, using the reversed languages model
|
# and decode backward, using the reversed languages model
|
||||||
decoded_text = tokenizer_dec.decode(input_ids, skip_special_tokens=True)
|
decoded_text = tokenizer_dec.decode(encoded_ids, skip_special_tokens=True)
|
||||||
self.assertEqual(decoded_text, src_text)
|
self.assertEqual(decoded_text, src_text)
|
||||||
|
|
||||||
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
@unittest.skip("FSMTConfig.__init__ requires non-optional args")
|
||||||
|
|||||||
Reference in New Issue
Block a user