[MarianTokenizer] implement save_vocabulary and other common methods (#4389)

This commit is contained in:
Sam Shleifer
2020-05-19 19:45:49 -04:00
committed by GitHub
parent 956c4c4eb4
commit efbc1c5a9d
3 changed files with 145 additions and 15 deletions

View File

@@ -129,11 +129,6 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
max_indices = logits.argmax(-1)
self.tokenizer.batch_decode(max_indices)
def test_tokenizer_equivalence(self):
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
expected = [38, 121, 14, 697, 38848, 0]
self.assertListEqual(expected, batch.input_ids[0].tolist())
def test_unk_support(self):
t = self.tokenizer
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()