Seq2seq trainer (#9241)
* Add label smoothing in Trainer * Add options for scheduler and Adafactor in Trainer * Put Seq2SeqTrainer in the main lib * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Address review comments and adapt scripts * Documentation * Move test not using script to tests folder Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -20,9 +20,16 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||
from transformers.training_args import ParallelMode
|
||||
from utils import (
|
||||
@@ -86,14 +93,14 @@ class DataTrainingArguments:
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
max_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
eval_max_length: Optional[int] = field(
|
||||
default=142,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
@@ -101,13 +108,6 @@ class DataTrainingArguments:
|
||||
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
|
||||
},
|
||||
)
|
||||
test_max_target_length: Optional[int] = field(
|
||||
default=142,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
||||
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
|
||||
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
|
||||
@@ -233,7 +233,7 @@ def main():
|
||||
type_path="train",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_train,
|
||||
max_target_length=data_args.max_target_length,
|
||||
max_target_length=data_args.max_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
@@ -246,7 +246,7 @@ def main():
|
||||
type_path="val",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_val,
|
||||
max_target_length=data_args.val_max_target_length,
|
||||
max_target_length=data_args.eval_max_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
@@ -259,7 +259,7 @@ def main():
|
||||
type_path="test",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_test,
|
||||
max_target_length=data_args.test_max_target_length,
|
||||
max_target_length=data_args.eval_max_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
@@ -273,13 +273,12 @@ def main():
|
||||
)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=compute_metrics_fn,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
all_metrics = {}
|
||||
@@ -310,7 +309,9 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(metric_key_prefix="val")
|
||||
metrics = trainer.evaluate(
|
||||
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
|
||||
)
|
||||
metrics["val_n_objs"] = data_args.n_val
|
||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||
|
||||
@@ -322,7 +323,12 @@ def main():
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
||||
test_output = trainer.predict(
|
||||
test_dataset=test_dataset,
|
||||
metric_key_prefix="test",
|
||||
max_length=data_args.eval_max_length,
|
||||
num_beams=data_args.eval_beams,
|
||||
)
|
||||
metrics = test_output.metrics
|
||||
metrics["test_n_objs"] = data_args.n_test
|
||||
|
||||
|
||||
@@ -17,8 +17,7 @@ import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel
|
||||
from transformers.file_utils import is_apex_available, is_datasets_available
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.integrations import is_fairscale_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
@@ -31,8 +30,7 @@ from transformers.testing_utils import (
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .finetune_trainer import main
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -120,119 +118,6 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
@slow
|
||||
def test_finetune_bert2bert(self):
|
||||
if not is_datasets_available():
|
||||
return
|
||||
|
||||
import datasets
|
||||
|
||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
||||
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||
bert2bert.config.max_length = 128
|
||||
|
||||
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||
|
||||
train_dataset = train_dataset.select(range(32))
|
||||
val_dataset = val_dataset.select(range(16))
|
||||
|
||||
rouge = datasets.load_metric("rouge")
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def _map_to_encoder_decoder_inputs(batch):
|
||||
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
||||
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
||||
batch["input_ids"] = inputs.input_ids
|
||||
batch["attention_mask"] = inputs.attention_mask
|
||||
|
||||
batch["decoder_input_ids"] = outputs.input_ids
|
||||
batch["labels"] = outputs.input_ids.copy()
|
||||
batch["labels"] = [
|
||||
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
||||
]
|
||||
batch["decoder_attention_mask"] = outputs.attention_mask
|
||||
|
||||
assert all([len(x) == 512 for x in inputs.input_ids])
|
||||
assert all([len(x) == 128 for x in outputs.input_ids])
|
||||
|
||||
return batch
|
||||
|
||||
def _compute_metrics(pred):
|
||||
labels_ids = pred.label_ids
|
||||
pred_ids = pred.predictions
|
||||
|
||||
# all unnecessary tokens are removed
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||
|
||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
||||
"rouge2"
|
||||
].mid
|
||||
|
||||
return {
|
||||
"rouge2_precision": round(rouge_output.precision, 4),
|
||||
"rouge2_recall": round(rouge_output.recall, 4),
|
||||
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
||||
}
|
||||
|
||||
# map train dataset
|
||||
train_dataset = train_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
train_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
# same for validation dataset
|
||||
val_dataset = val_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
val_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
output_dir=output_dir,
|
||||
per_device_train_batch_size=batch_size,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
predict_with_generate=True,
|
||||
evaluation_strategy="steps",
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
warmup_steps=0,
|
||||
eval_steps=2,
|
||||
logging_steps=2,
|
||||
)
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=bert2bert,
|
||||
args=training_args,
|
||||
compute_metrics=_compute_metrics,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
)
|
||||
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(
|
||||
self,
|
||||
eval_steps: int,
|
||||
@@ -252,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
--n_train 8
|
||||
--n_val 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--max_length {max_len}
|
||||
--eval_max_length {max_len}
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_predict
|
||||
|
||||
@@ -29,7 +29,7 @@ python finetune_trainer.py \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --logging_first_step \
|
||||
|
||||
@@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 500 --eval_steps 500 \
|
||||
--logging_first_step --logging_steps 200 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
|
||||
--do_train --do_eval \
|
||||
--evaluation_strategy steps \
|
||||
--prediction_loss_only \
|
||||
|
||||
@@ -32,7 +32,7 @@ python finetune_trainer.py \
|
||||
--num_train_epochs=2 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--logging_first_step \
|
||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--max_length 56 --eval_max_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --sortish_sampler \
|
||||
|
||||
@@ -24,8 +24,7 @@ python finetune_trainer.py \
|
||||
--src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_embeds \
|
||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||
--max_source_length 128 --max_target_length 128 \
|
||||
--val_max_target_length 128 --test_max_target_length 128 \
|
||||
--max_source_length 128 --max_length 128 --eval_max_length 128 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs 6 \
|
||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||
|
||||
@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
max_target_length=self.data_args.max_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
**self.dataset_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user