MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .utils import _torch_available
|
||||
|
||||
|
||||
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||
@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
|
||||
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
||||
zh_code = ">>zh<<"
|
||||
ORG_NAME = "Helsinki-NLP/"
|
||||
FRAMEWORK = "pt" if _torch_available else "tf"
|
||||
|
||||
|
||||
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
contents = [x.name for x in Path(save_dir).glob("*")]
|
||||
self.assertIn("source.spm", contents)
|
||||
MarianTokenizer.from_pretrained(save_dir)
|
||||
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
tok = self.get_tokenizer()
|
||||
|
||||
batch = tok.prepare_translation_batch(
|
||||
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||
|
||||
def test_outputs_can_be_shorter(self):
|
||||
tok = self.get_tokenizer()
|
||||
batch_smaller = tok.prepare_translation_batch(
|
||||
["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertIsInstance(batch_smaller, BatchEncoding)
|
||||
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
|
||||
|
||||
Reference in New Issue
Block a user