MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)
This commit is contained in:
@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_to_max_length: bool = True,
|
pad_to_max_length: bool = True,
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
|
truncation_strategy="only_first",
|
||||||
|
padding="longest",
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||||
Arguments:
|
Arguments:
|
||||||
@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||||
self.current_spm = self.spm_source
|
self.current_spm = self.spm_source
|
||||||
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
||||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
tokenizer_kwargs = dict(
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
|
truncation_strategy=truncation_strategy,
|
||||||
|
padding=padding,
|
||||||
)
|
)
|
||||||
|
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||||
|
|
||||||
if tgt_texts is None:
|
if tgt_texts is None:
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
self.current_spm = self.spm_target
|
self.current_spm = self.spm_target
|
||||||
decoder_inputs: BatchEncoding = self.batch_encode_plus(
|
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
pad_to_max_length=pad_to_max_length,
|
|
||||||
)
|
|
||||||
for k, v in decoder_inputs.items():
|
for k, v in decoder_inputs.items():
|
||||||
model_inputs[f"decoder_{k}"] = v
|
model_inputs[f"decoder_{k}"] = v
|
||||||
self.current_spm = self.spm_source
|
self.current_spm = self.spm_source
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
|
|||||||
from transformers.tokenization_utils import BatchEncoding
|
from transformers.tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
from .utils import _torch_available
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||||
@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
|
|||||||
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
||||||
zh_code = ">>zh<<"
|
zh_code = ">>zh<<"
|
||||||
ORG_NAME = "Helsinki-NLP/"
|
ORG_NAME = "Helsinki-NLP/"
|
||||||
|
FRAMEWORK = "pt" if _torch_available else "tf"
|
||||||
|
|
||||||
|
|
||||||
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
contents = [x.name for x in Path(save_dir).glob("*")]
|
contents = [x.name for x in Path(save_dir).glob("*")]
|
||||||
self.assertIn("source.spm", contents)
|
self.assertIn("source.spm", contents)
|
||||||
MarianTokenizer.from_pretrained(save_dir)
|
MarianTokenizer.from_pretrained(save_dir)
|
||||||
|
|
||||||
|
def test_outputs_not_longer_than_maxlen(self):
|
||||||
|
tok = self.get_tokenizer()
|
||||||
|
|
||||||
|
batch = tok.prepare_translation_batch(
|
||||||
|
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
|
||||||
|
)
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||||
|
|
||||||
|
def test_outputs_can_be_shorter(self):
|
||||||
|
tok = self.get_tokenizer()
|
||||||
|
batch_smaller = tok.prepare_translation_batch(
|
||||||
|
["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
|
||||||
|
)
|
||||||
|
self.assertIsInstance(batch_smaller, BatchEncoding)
|
||||||
|
self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
|
||||||
|
|||||||
Reference in New Issue
Block a user