[s2s] add config params like Dropout in Seq2SeqTrainingArguments (#7532)
This commit is contained in:
@@ -269,7 +269,11 @@ class Seq2SeqDataCollator:
|
||||
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
|
||||
if data_args.src_lang is not None:
|
||||
self.dataset_kwargs["src_lang"] = data_args.src_lang
|
||||
if data_args.tgt_lang is not None:
|
||||
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
@@ -310,14 +314,12 @@ class Seq2SeqDataCollator:
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.data_args.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.data_args.tgt_lang,
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user