[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)

* Make Seq2Seq Trainer more similar to Trainer

* fix typo

* fix seq2seq trainer

* remove from tests

* remove lock

* remove train files

* delete test files

* correct typo

* check at init

* make sure trainer is not slowed down on TPU

* correct isort

* remove use cache

* fix use cache

* add last use cache = false
This commit is contained in:
Patrick von Platen
2020-10-23 23:05:51 +02:00
committed by GitHub
parent 59b5953d89
commit 3c682ea15c
3 changed files with 185 additions and 43 deletions

View File

@@ -16,7 +16,6 @@ from transformers import (
)
from transformers.trainer_utils import EvaluationStrategy
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataCollator,
Seq2SeqDataset,
assert_all_frozen,
@@ -138,6 +137,10 @@ class DataTrainingArguments:
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
)
def main():
@@ -223,7 +226,7 @@ def main():
freeze_params(model.get_encoder())
assert_all_frozen(model.get_encoder())
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
dataset_class = Seq2SeqDataset
# Get datasets
train_dataset = (