[s2s] Use prepare_translation_batch for Marian finetuning (#6293)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -127,10 +127,12 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
pad_to_max_length: bool = True,
|
||||
return_tensors: str = "pt",
|
||||
truncation_strategy="only_first",
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
Arguments:
|
||||
@@ -162,6 +164,9 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
self.current_spm = self.spm_target
|
||||
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
|
||||
for k, v in decoder_inputs.items():
|
||||
|
||||
Reference in New Issue
Block a user