[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)
* Make Seq2Seq Trainer more similar to Trainer * fix typo * fix seq2seq trainer * remove from tests * remove lock * remove train files * delete test files * correct typo * check at init * make sure trainer is not slowed down on TPU * correct isort * remove use cache * fix use cache * add last use cache = false
This commit is contained in:
committed by
GitHub
parent
59b5953d89
commit
3c682ea15c
@@ -16,7 +16,6 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import EvaluationStrategy
|
||||
from utils import (
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataCollator,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
@@ -138,6 +137,10 @@ class DataTrainingArguments:
|
||||
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -223,7 +226,7 @@ def main():
|
||||
freeze_params(model.get_encoder())
|
||||
assert_all_frozen(model.get_encoder())
|
||||
|
||||
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
dataset_class = Seq2SeqDataset
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
|
||||
Reference in New Issue
Block a user