Seq2seq trainer (#9241)

* Add label smoothing in Trainer

* Add options for scheduler and Adafactor in Trainer

* Put Seq2SeqTrainer in the main lib

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments and adapt scripts

* Documentation

* Move test not using script to tests folder

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-12-22 11:33:44 -05:00
committed by GitHub
parent 1fc7119181
commit 490b39e614
20 changed files with 655 additions and 166 deletions

View File

@@ -20,9 +20,16 @@ from dataclasses import dataclass, field
from typing import Optional
import transformers
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvaluationStrategy, is_main_process
from transformers.training_args import ParallelMode
from utils import (
@@ -86,14 +93,14 @@ class DataTrainingArguments:
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
max_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
eval_max_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
@@ -101,13 +108,6 @@ class DataTrainingArguments:
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
@@ -233,7 +233,7 @@ def main():
type_path="train",
data_dir=data_args.data_dir,
n_obs=data_args.n_train,
max_target_length=data_args.max_target_length,
max_target_length=data_args.max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -246,7 +246,7 @@ def main():
type_path="val",
data_dir=data_args.data_dir,
n_obs=data_args.n_val,
max_target_length=data_args.val_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -259,7 +259,7 @@ def main():
type_path="test",
data_dir=data_args.data_dir,
n_obs=data_args.n_test,
max_target_length=data_args.test_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -273,13 +273,12 @@ def main():
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=compute_metrics_fn,
data_args=data_args,
tokenizer=tokenizer,
)
all_metrics = {}
@@ -310,7 +309,9 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="val")
metrics = trainer.evaluate(
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
)
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -322,7 +323,12 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
test_output = trainer.predict(
test_dataset=test_dataset,
metric_key_prefix="test",
max_length=data_args.eval_max_length,
num_beams=data_args.eval_beams,
)
metrics = test_output.metrics
metrics["test_n_objs"] = data_args.n_test

View File

@@ -17,8 +17,7 @@ import sys
import unittest
from unittest.mock import patch
from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_apex_available, is_datasets_available
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
TestCasePlus,
@@ -31,8 +30,7 @@ from transformers.testing_utils import (
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .finetune_trainer import main
set_seed(42)
@@ -120,119 +118,6 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_generations.txt" in contents
assert "test_results.json" in contents
@slow
def test_finetune_bert2bert(self):
if not is_datasets_available():
return
import datasets
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.max_length = 128
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4
def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask
batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask
assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])
return batch
def _compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid
return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
output_dir = self.get_auto_remove_tmp_dir()
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluation_strategy="steps",
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
# start training
trainer.train()
def run_trainer(
self,
eval_steps: int,
@@ -252,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
--n_train 8
--n_val 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--max_length {max_len}
--eval_max_length {max_len}
--do_train
--do_eval
--do_predict

View File

@@ -29,7 +29,7 @@ python finetune_trainer.py \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 3000 --eval_steps 3000 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --logging_first_step \

View File

@@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
--num_train_epochs=6 \
--save_steps 500 --eval_steps 500 \
--logging_first_step --logging_steps 200 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval \
--evaluation_strategy steps \
--prediction_loss_only \

View File

@@ -32,7 +32,7 @@ python finetune_trainer.py \
--num_train_epochs=2 \
--save_steps 3000 --eval_steps 3000 \
--logging_first_step \
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--max_length 56 --eval_max_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --sortish_sampler \

View File

@@ -24,8 +24,7 @@ python finetune_trainer.py \
--src_lang en_XX --tgt_lang ro_RO \
--freeze_embeds \
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_target_length 128 \
--val_max_target_length 128 --test_max_target_length 128 \
--max_source_length 128 --max_length 128 --eval_max_length 128 \
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \

View File

@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
max_target_length=self.data_args.max_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,