Seq2SeqTrainer (#6769)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -189,6 +189,34 @@ from transformers import AutoModelForSeq2SeqLM
|
|||||||
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Fine-tuning using Seq2SeqTrainer
|
||||||
|
To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Except the `Trainer` releated `TrainingArguments`, it shares the same argument names as that of `finetune.py` file. One notable difference is that, calculating generative metrics (BLEU, ROUGE) is optional and is controlled using the `--predict_with_generate` argument, set this argument to calculate BLEU and ROUGE metrics.
|
||||||
|
|
||||||
|
To see all the possible command line options, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
**At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.**
|
||||||
|
|
||||||
|
All `Seq2SeqTrainer` based fine-tuning scripts are included in the `builtin_trainer` directory.
|
||||||
|
|
||||||
|
#### TPU Training
|
||||||
|
`Seq2SeqTrainer` supports TPU training with few caveats
|
||||||
|
1. As `generate` method does not work on TPU at the moment, `predict_with_generate` can not be used. You should use `--prediction_loss_only` to only calculate loss, and do not set `--do_predict` and `--predict_with_generate`.
|
||||||
|
2. All sequences should be padded to be of equal length otherwise it leads to extremely slow training. (`finetune_trainer.py` does this automatically when running on TPU.)
|
||||||
|
|
||||||
|
We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a --num_cores flag to this script, then your regular training script with its arguments (this is similar to the torch.distributed.launch helper for torch.distributed).
|
||||||
|
|
||||||
|
`builtin_trainer/finetune_tpu.sh` script provides minimal arguments needed for TPU training.
|
||||||
|
|
||||||
|
Following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 and should complete one epoch in ~5-6 mins.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./builtin_trainer/train_distil_marian_enro_tpu.sh
|
||||||
|
```
|
||||||
|
|
||||||
### Evaluation Commands
|
### Evaluation Commands
|
||||||
|
|
||||||
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
||||||
|
|||||||
9
examples/seq2seq/builtin_trainer/finetune.sh
Normal file
9
examples/seq2seq/builtin_trainer/finetune.sh
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
|
# run ./builtin_trainer/finetune.sh --help to see all the possible options
|
||||||
|
python finetune_trainer.py \
|
||||||
|
--learning_rate=3e-5 \
|
||||||
|
--fp16 \
|
||||||
|
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||||
|
--predict_with_generate \
|
||||||
|
--n_val 1000 \
|
||||||
|
"$@"
|
||||||
11
examples/seq2seq/builtin_trainer/finetune_tpu.sh
Normal file
11
examples/seq2seq/builtin_trainer/finetune_tpu.sh
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
export TPU_NUM_CORES=8
|
||||||
|
|
||||||
|
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||||
|
# run ./builtin_trainer/finetune_tpu.sh --help to see all the possible options
|
||||||
|
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||||
|
finetune_trainer.py \
|
||||||
|
--learning_rate=3e-5 \
|
||||||
|
--do_train --do_eval --evaluate_during_training \
|
||||||
|
--prediction_loss_only \
|
||||||
|
--n_val 1000 \
|
||||||
|
"$@"
|
||||||
23
examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh
Normal file
23
examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
export WANDB_PROJECT=distil-marian
|
||||||
|
export BS=64
|
||||||
|
export GAS=1
|
||||||
|
export m=sshleifer/student_marian_en_ro_6_3
|
||||||
|
export MAX_LEN=128
|
||||||
|
python finetune_trainer.py \
|
||||||
|
--tokenizer_name $m --model_name_or_path $m \
|
||||||
|
--data_dir $ENRO_DIR \
|
||||||
|
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
|
||||||
|
--learning_rate=3e-4 \
|
||||||
|
--warmup_steps 500 --sortish_sampler \
|
||||||
|
--fp16 \
|
||||||
|
--gradient_accumulation_steps=$GAS \
|
||||||
|
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--num_train_epochs=6 \
|
||||||
|
--save_steps 3000 --eval_steps 3000 \
|
||||||
|
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||||
|
--do_train --do_eval --do_predict --evaluate_during_training\
|
||||||
|
--predict_with_generate --logging_first_step \
|
||||||
|
--task translation --label_smoothing 0.1 \
|
||||||
|
--run_name marian_en_ro_6_3 \
|
||||||
|
"$@"
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
export WANDB_PROJECT=distil-marian
|
||||||
|
export BS=64
|
||||||
|
export m=sshleifer/student_marian_en_ro_6_3
|
||||||
|
export MAX_LEN=128
|
||||||
|
export TPU_NUM_CORES=8
|
||||||
|
|
||||||
|
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||||
|
finetune_trainer.py \
|
||||||
|
--tokenizer_name $m --model_name_or_path $m \
|
||||||
|
--data_dir $ENRO_DIR \
|
||||||
|
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
|
||||||
|
--learning_rate=3e-4 \
|
||||||
|
--warmup_steps 500 \
|
||||||
|
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--num_train_epochs=6 \
|
||||||
|
--save_steps 500 --eval_steps 500 \
|
||||||
|
--logging_first_step --logging_steps 200 \
|
||||||
|
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||||
|
--do_train --do_eval --evaluate_during_training \
|
||||||
|
--prediction_loss_only \
|
||||||
|
--task translation --label_smoothing 0.1 \
|
||||||
|
--run_name marian_en_ro_6_3 \
|
||||||
|
"$@"
|
||||||
26
examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh
Normal file
26
examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
export WANDB_PROJECT=distilbart-cnn
|
||||||
|
export BS=32
|
||||||
|
export GAS=1
|
||||||
|
export m=sshleifer/student_cnn_12_6
|
||||||
|
export tok=facebook/bart-large
|
||||||
|
export MAX_TGT_LEN=142
|
||||||
|
|
||||||
|
python finetune_trainer.py \
|
||||||
|
--model_name_or_path $m --tokenizer_name $tok \
|
||||||
|
--data_dir $CNN_DIR \
|
||||||
|
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
|
||||||
|
--learning_rate=3e-5 \
|
||||||
|
--warmup_steps 500 --sortish_sampler \
|
||||||
|
--fp16 \
|
||||||
|
--n_val 500 \
|
||||||
|
--gradient_accumulation_steps=$GAS \
|
||||||
|
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||||
|
--freeze_encoder --freeze_embeds \
|
||||||
|
--num_train_epochs=2 \
|
||||||
|
--save_steps 3000 --eval_steps 3000 \
|
||||||
|
--logging_first_step \
|
||||||
|
--max_target_length $MAX_TGT_LEN --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||||
|
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||||
|
--predict_with_generate \
|
||||||
|
--run_name distilbart-cnn-12-6 \
|
||||||
|
"$@"
|
||||||
22
examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh
Normal file
22
examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
python finetune_trainer.py \
|
||||||
|
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||||
|
--data_dir $ENRO_DIR \
|
||||||
|
--output_dir mbart_cc25_enro --overwrite_output_dir \
|
||||||
|
--learning_rate=3e-5 \
|
||||||
|
--warmup_steps 500 \
|
||||||
|
--fp16 \
|
||||||
|
--label_smoothing 0.1 \
|
||||||
|
--adam_eps 1e-06 \
|
||||||
|
--src_lang en_XX --tgt_lang ro_RO \
|
||||||
|
--freeze_embeds \
|
||||||
|
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||||
|
--max_source_length 128 --max_target_length 128 \
|
||||||
|
--val_max_target_length 128 --test_max_target_length 128 \
|
||||||
|
--sortish_sampler \
|
||||||
|
--num_train_epochs 6 \
|
||||||
|
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||||
|
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||||
|
--predict_with_generate --logging_first_step
|
||||||
|
--task translation \
|
||||||
|
--run_name mbart_en_ro \
|
||||||
|
"$@"
|
||||||
442
examples/seq2seq/finetune_trainer.py
Normal file
442
examples/seq2seq/finetune_trainer.py
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from seq2seq_trainer import Seq2SeqTrainer
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BartTokenizer,
|
||||||
|
EvalPrediction,
|
||||||
|
HfArgumentParser,
|
||||||
|
MBartTokenizer,
|
||||||
|
T5Tokenizer,
|
||||||
|
TrainingArguments,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
from utils import (
|
||||||
|
LegacySeq2SeqDataset,
|
||||||
|
Seq2SeqDataset,
|
||||||
|
assert_all_frozen,
|
||||||
|
calculate_bleu,
|
||||||
|
calculate_rouge,
|
||||||
|
freeze_params,
|
||||||
|
lmap,
|
||||||
|
trim_batch,
|
||||||
|
use_task_specific_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqDataCollator:
|
||||||
|
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.pad_token_id = tokenizer.pad_token_id
|
||||||
|
self.data_args = data_args
|
||||||
|
self.tpu_num_cores = tpu_num_cores
|
||||||
|
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||||
|
|
||||||
|
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||||
|
batch = self._encode(batch)
|
||||||
|
input_ids, attention_mask, labels = (
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["attention_mask"],
|
||||||
|
batch["labels"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||||
|
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||||
|
labels = torch.stack([x["labels"] for x in batch])
|
||||||
|
|
||||||
|
labels = trim_batch(labels, self.pad_token_id)
|
||||||
|
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
if isinstance(self.tokenizer, T5Tokenizer):
|
||||||
|
decoder_input_ids = self._shift_right_t5(labels)
|
||||||
|
labels = labels
|
||||||
|
else:
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||||
|
labels = labels
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def _shift_right_t5(self, input_ids):
|
||||||
|
decoder_start_token_id = self.pad_token_id
|
||||||
|
|
||||||
|
assert (
|
||||||
|
decoder_start_token_id is not None
|
||||||
|
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||||
|
|
||||||
|
# shift inputs to the right
|
||||||
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
|
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||||
|
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
|
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
[x["src_texts"] for x in batch],
|
||||||
|
src_lang=self.data_args.src_lang,
|
||||||
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
|
tgt_lang=self.data_args.tgt_lang,
|
||||||
|
max_length=self.data_args.max_source_length,
|
||||||
|
max_target_length=self.data_args.max_target_length,
|
||||||
|
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||||
|
return_tensors="pt",
|
||||||
|
add_prefix_space=self.add_prefix_space,
|
||||||
|
)
|
||||||
|
return batch_encoding.data
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
label_smoothing (:obj:`float`, `optional`, defaults to 0):
|
||||||
|
The label smoothing epsilon to apply (if not zero).
|
||||||
|
sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to SortishSamler or not. It sorts the inputs according to lenghts in-order to minimizing the padding size.
|
||||||
|
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
||||||
|
"""
|
||||||
|
|
||||||
|
label_smoothing: Optional[float] = field(
|
||||||
|
default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
|
||||||
|
)
|
||||||
|
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."})
|
||||||
|
predict_with_generate: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name_or_path: str = field(
|
||||||
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
|
)
|
||||||
|
config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
tokenizer_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
cache_dir: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||||
|
)
|
||||||
|
freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."})
|
||||||
|
freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataTrainingArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
data_dir: str = field(
|
||||||
|
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
||||||
|
)
|
||||||
|
task: Optional[str] = field(
|
||||||
|
default="summarization",
|
||||||
|
metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"},
|
||||||
|
)
|
||||||
|
max_source_length: Optional[int] = field(
|
||||||
|
default=1024,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_target_length: Optional[int] = field(
|
||||||
|
default=128,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
val_max_target_length: Optional[int] = field(
|
||||||
|
default=142,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
test_max_target_length: Optional[int] = field(
|
||||||
|
default=142,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
||||||
|
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
|
||||||
|
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
|
||||||
|
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
||||||
|
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
||||||
|
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
|
# or by passing the --help flag to this script.
|
||||||
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
|
||||||
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||||
|
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
if (
|
||||||
|
os.path.exists(training_args.output_dir)
|
||||||
|
and os.listdir(training_args.output_dir)
|
||||||
|
and training_args.do_train
|
||||||
|
and not training_args.overwrite_output_dir
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
|
training_args.local_rank,
|
||||||
|
training_args.device,
|
||||||
|
training_args.n_gpu,
|
||||||
|
bool(training_args.local_rank != -1),
|
||||||
|
training_args.fp16,
|
||||||
|
)
|
||||||
|
logger.info("Training/evaluation parameters %s", training_args)
|
||||||
|
|
||||||
|
# Set seed
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
#
|
||||||
|
# Distributed training:
|
||||||
|
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||||
|
# download model & vocab.
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
from_tf=".ckpt" in model_args.model_name_or_path,
|
||||||
|
config=config,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
# use task specific params
|
||||||
|
use_task_specific_params(model, data_args.task)
|
||||||
|
|
||||||
|
# set num_beams for evaluation
|
||||||
|
if data_args.eval_beams is not None:
|
||||||
|
model.config.num_beams = data_args.eval_beams
|
||||||
|
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
|
||||||
|
|
||||||
|
# set max length for generation
|
||||||
|
model.config.max_generate_length = data_args.val_max_target_length
|
||||||
|
|
||||||
|
# set decoder_start_token_id for MBart
|
||||||
|
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||||
|
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||||
|
model.config.decoder_start_token_id = decoder_start_token_id
|
||||||
|
|
||||||
|
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||||
|
def non_pad_len(tokens: np.ndarray) -> int:
|
||||||
|
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||||
|
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||||
|
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||||
|
pred_str = lmap(str.strip, pred_str)
|
||||||
|
label_str = lmap(str.strip, label_str)
|
||||||
|
return pred_str, label_str
|
||||||
|
|
||||||
|
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||||
|
pred_str, label_str = decode_pred(pred)
|
||||||
|
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||||
|
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||||
|
rouge.update({"gen_len": summ_len})
|
||||||
|
return rouge
|
||||||
|
|
||||||
|
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||||
|
pred_str, label_str = decode_pred(pred)
|
||||||
|
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||||
|
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||||
|
bleu.update({"gen_len": gen_len})
|
||||||
|
return bleu
|
||||||
|
|
||||||
|
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||||
|
return compute_metrics_fn
|
||||||
|
|
||||||
|
def freeze_embeds(model: torch.nn.Module):
|
||||||
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
|
try:
|
||||||
|
freeze_params(model.model.shared)
|
||||||
|
for d in [model.model.encoder, model.model.decoder]:
|
||||||
|
freeze_params(d.embed_positions)
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
except AttributeError:
|
||||||
|
freeze_params(model.shared)
|
||||||
|
for d in [model.encoder, model.decoder]:
|
||||||
|
freeze_params(d.embed_tokens)
|
||||||
|
|
||||||
|
if model_args.freeze_embeds:
|
||||||
|
freeze_embeds(model)
|
||||||
|
if model_args.freeze_encoder:
|
||||||
|
freeze_params(model.get_encoder())
|
||||||
|
assert_all_frozen(model.get_encoder())
|
||||||
|
|
||||||
|
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||||
|
|
||||||
|
# Get datasets
|
||||||
|
train_dataset = (
|
||||||
|
dataset_class(
|
||||||
|
tokenizer,
|
||||||
|
type_path="train",
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
n_obs=data_args.n_train,
|
||||||
|
max_target_length=data_args.max_target_length,
|
||||||
|
max_source_length=data_args.max_source_length,
|
||||||
|
prefix=model.config.prefix or "",
|
||||||
|
)
|
||||||
|
if training_args.do_train
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
eval_dataset = (
|
||||||
|
dataset_class(
|
||||||
|
tokenizer,
|
||||||
|
type_path="val",
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
n_obs=data_args.n_val,
|
||||||
|
max_target_length=data_args.val_max_target_length,
|
||||||
|
max_source_length=data_args.max_source_length,
|
||||||
|
prefix=model.config.prefix or "",
|
||||||
|
)
|
||||||
|
if training_args.do_eval
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
test_dataset = (
|
||||||
|
dataset_class(
|
||||||
|
tokenizer,
|
||||||
|
type_path="test",
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
n_obs=data_args.n_test,
|
||||||
|
max_target_length=data_args.test_max_target_length,
|
||||||
|
max_source_length=data_args.max_source_length,
|
||||||
|
prefix=model.config.prefix or "",
|
||||||
|
)
|
||||||
|
if training_args.do_predict
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize our Trainer
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||||
|
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
trainer.train(
|
||||||
|
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||||
|
)
|
||||||
|
trainer.save_model()
|
||||||
|
# For convenience, we also re-save the tokenizer to the same directory,
|
||||||
|
# so that you can share your model easily on huggingface.co/models =)
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
eval_results = {}
|
||||||
|
if training_args.do_eval:
|
||||||
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
|
result = trainer.evaluate()
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
logger.info("***** Eval results *****")
|
||||||
|
for key, value in result.items():
|
||||||
|
logger.info(" %s = %s", key, value)
|
||||||
|
|
||||||
|
with open(output_eval_file, "w") as f:
|
||||||
|
json.dump(result, f)
|
||||||
|
|
||||||
|
eval_results.update(result)
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
logging.info("*** Test ***")
|
||||||
|
|
||||||
|
test_output = trainer.predict(test_dataset=test_dataset)
|
||||||
|
test_metrics = test_output.metrics
|
||||||
|
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
|
||||||
|
|
||||||
|
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
|
||||||
|
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
logger.info("***** Test results *****")
|
||||||
|
for key, value in test_metrics.items():
|
||||||
|
logger.info(" %s = %s", key, value)
|
||||||
|
|
||||||
|
with open(output_test_file, "w") as f:
|
||||||
|
json.dump(test_metrics, f)
|
||||||
|
|
||||||
|
if training_args.predict_with_generate:
|
||||||
|
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
|
||||||
|
test_preds = lmap(str.strip, test_preds)
|
||||||
|
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||||
|
with open(output_test_pred_file, "w") as f:
|
||||||
|
f.write("\n".join(test_preds))
|
||||||
|
|
||||||
|
return eval_results
|
||||||
|
|
||||||
|
|
||||||
|
def _mp_fn(index):
|
||||||
|
# For xla_spawn (TPUs)
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
126
examples/seq2seq/seq2seq_trainer.py
Normal file
126
examples/seq2seq/seq2seq_trainer.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
|
from transformers.trainer import get_tpu_sampler
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import label_smoothed_nll_loss
|
||||||
|
except ImportError:
|
||||||
|
from utils import label_smoothed_nll_loss
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqTrainer(Trainer):
|
||||||
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
|
return None
|
||||||
|
elif is_torch_tpu_available():
|
||||||
|
return get_tpu_sampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
if self.args.sortish_sampler:
|
||||||
|
self.train_dataset.make_sortish_sampler(
|
||||||
|
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
RandomSampler(self.train_dataset)
|
||||||
|
if self.args.local_rank == -1
|
||||||
|
else DistributedSampler(self.train_dataset)
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs):
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
outputs = model(**inputs, use_cache=False)
|
||||||
|
logits = outputs[0]
|
||||||
|
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
|
||||||
|
|
||||||
|
def _compute_loss(self, logits, labels, ignore_index):
|
||||||
|
if self.args.label_smoothing == 0:
|
||||||
|
# Same behavior as modeling_bart.py
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||||
|
assert logits.shape[-1] == self.model.config.vocab_size
|
||||||
|
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
|
else:
|
||||||
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
|
loss, nll_loss = label_smoothed_nll_loss(
|
||||||
|
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def prediction_step(
|
||||||
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||||
|
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||||
|
|
||||||
|
Subclass and override to inject custom behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`nn.Module`):
|
||||||
|
The model to evaluate.
|
||||||
|
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||||
|
The inputs and targets of the model.
|
||||||
|
|
||||||
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||||
|
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||||
|
prediction_loss_only (:obj:`bool`):
|
||||||
|
Whether or not to return the loss only.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
A tuple with the loss, logits and labels (each being optional).
|
||||||
|
"""
|
||||||
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
max_length = (
|
||||||
|
model.config.max_generate_length
|
||||||
|
if hasattr(model.config, "max_generate_length")
|
||||||
|
else model.config.max_position_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||||
|
generated_tokens = model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
use_cache=True,
|
||||||
|
num_beams=model.config.num_beams,
|
||||||
|
max_length=max_length,
|
||||||
|
)
|
||||||
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
|
generated_tokens = self._pad_tensors_to_max_len(
|
||||||
|
generated_tokens, max_length, model.config.pad_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
labels_out = inputs.get("labels")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
logits = outputs[1]
|
||||||
|
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
|
||||||
|
loss = loss.mean().item()
|
||||||
|
if self.args.prediction_loss_only:
|
||||||
|
logits = None
|
||||||
|
else:
|
||||||
|
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||||
|
|
||||||
|
if self.args.prediction_loss_only:
|
||||||
|
return (loss, None, None)
|
||||||
|
|
||||||
|
labels_out = labels_out.detach()
|
||||||
|
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
|
||||||
|
return (loss, logits.detach(), labels)
|
||||||
|
|
||||||
|
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||||
|
padded_tensor = pad_token_id * torch.ones(
|
||||||
|
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||||
|
)
|
||||||
|
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||||
|
return padded_tensor
|
||||||
96
examples/seq2seq/test_finetune_trainer.py
Normal file
96
examples/seq2seq/test_finetune_trainer.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||||
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
|
from .finetune_trainer import main
|
||||||
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
|
from .utils import load_json
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_NAME = MBART_TINY
|
||||||
|
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||||
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_download():
|
||||||
|
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||||
|
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||||
|
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_finetune_trainer():
|
||||||
|
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||||
|
max_len = "128"
|
||||||
|
num_train_epochs = 4
|
||||||
|
eval_steps = 2
|
||||||
|
argv = [
|
||||||
|
"--model_name_or_path",
|
||||||
|
MARIAN_MODEL,
|
||||||
|
"--data_dir",
|
||||||
|
data_dir,
|
||||||
|
"--output_dir",
|
||||||
|
output_dir,
|
||||||
|
"--overwrite_output_dir",
|
||||||
|
"--n_train",
|
||||||
|
"8",
|
||||||
|
"--n_val",
|
||||||
|
"8",
|
||||||
|
"--max_source_length",
|
||||||
|
max_len,
|
||||||
|
"--max_target_length",
|
||||||
|
max_len,
|
||||||
|
"--val_max_target_length",
|
||||||
|
max_len,
|
||||||
|
"--do_train",
|
||||||
|
"--do_eval",
|
||||||
|
"--do_predict",
|
||||||
|
"--num_train_epochs",
|
||||||
|
str(num_train_epochs),
|
||||||
|
"--per_device_train_batch_size",
|
||||||
|
"4",
|
||||||
|
"--per_device_eval_batch_size",
|
||||||
|
"4",
|
||||||
|
"--learning_rate",
|
||||||
|
"3e-4",
|
||||||
|
"--warmup_steps",
|
||||||
|
"8",
|
||||||
|
"--evaluate_during_training",
|
||||||
|
"--predict_with_generate",
|
||||||
|
"--logging_steps",
|
||||||
|
0,
|
||||||
|
"--save_steps",
|
||||||
|
str(eval_steps),
|
||||||
|
"--eval_steps",
|
||||||
|
str(eval_steps),
|
||||||
|
"--sortish_sampler",
|
||||||
|
"--label_smoothing",
|
||||||
|
"0.1",
|
||||||
|
"--task",
|
||||||
|
"translation",
|
||||||
|
]
|
||||||
|
|
||||||
|
testargs = ["finetune_trainer.py"] + argv
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
main()
|
||||||
|
|
||||||
|
# Check metrics
|
||||||
|
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||||
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
|
first_step_stats = eval_metrics[0]
|
||||||
|
last_step_stats = eval_metrics[-1]
|
||||||
|
|
||||||
|
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||||
|
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||||
|
|
||||||
|
# test if do_predict saves generations and metrics
|
||||||
|
contents = os.listdir(output_dir)
|
||||||
|
contents = {os.path.basename(p) for p in contents}
|
||||||
|
assert "test_generations.txt" in contents
|
||||||
|
assert "test_results.json" in contents
|
||||||
72
examples/seq2seq/xla_spawn.py
Normal file
72
examples/seq2seq/xla_spawn.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""
|
||||||
|
A simple launcher script for TPU training
|
||||||
|
|
||||||
|
Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
|
||||||
|
|
||||||
|
::
|
||||||
|
>>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
|
||||||
|
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
|
||||||
|
arguments of your training script)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from argparse import REMAINDER, ArgumentParser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""
|
||||||
|
Helper function parsing the command line options
|
||||||
|
@retval ArgumentParser
|
||||||
|
"""
|
||||||
|
parser = ArgumentParser(
|
||||||
|
description=(
|
||||||
|
"PyTorch TPU distributed training launch "
|
||||||
|
"helper utility that will spawn up "
|
||||||
|
"multiple distributed processes"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional arguments for the launch helper
|
||||||
|
parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
|
||||||
|
|
||||||
|
# positional
|
||||||
|
parser.add_argument(
|
||||||
|
"training_script",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"The full path to the single TPU training "
|
||||||
|
"program/script to be launched in parallel, "
|
||||||
|
"followed by all the arguments for the "
|
||||||
|
"training script"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# rest from the training program
|
||||||
|
parser.add_argument("training_script_args", nargs=REMAINDER)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Import training_script as a module.
|
||||||
|
script_fpath = Path(args.training_script)
|
||||||
|
sys.path.append(str(script_fpath.parent.resolve()))
|
||||||
|
mod_name = script_fpath.stem
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
|
||||||
|
# Patch sys.argv
|
||||||
|
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
|
||||||
|
|
||||||
|
xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user