[s2s]Use prepare_translation_batch for Marian finetuning (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-08-06 14:58:38 -04:00
committed by GitHub
parent 2f2aa0c89c
commit 2804fff839
5 changed files with 22 additions and 12 deletions

View File

@@ -127,10 +127,12 @@ class MarianTokenizer(PreTrainedTokenizer):
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
pad_to_max_length: bool = True,
return_tensors: str = "pt",
truncation_strategy="only_first",
padding="longest",
**unused,
) -> BatchEncoding:
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
Arguments:
@@ -162,6 +164,9 @@ class MarianTokenizer(PreTrainedTokenizer):
if tgt_texts is None:
return model_inputs
if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length
self.current_spm = self.spm_target
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():