[s2s] add config params like Dropout in Seq2SeqTrainingArguments (#7532)

This commit is contained in:
Suraj Patil
2020-10-04 22:12:30 +05:30
committed by GitHub
parent 9bdce3a4f9
commit 99cb924bfb
4 changed files with 35 additions and 18 deletions

View File

@@ -269,7 +269,11 @@ class Seq2SeqDataCollator:
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
if data_args.src_lang is not None:
self.dataset_kwargs["src_lang"] = data_args.src_lang
if data_args.tgt_lang is not None:
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
@@ -310,14 +314,12 @@ class Seq2SeqDataCollator:
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.data_args.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.data_args.tgt_lang,
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
**self.dataset_kwargs,
)
return batch_encoding.data