Seq2SeqTrainer (#6769)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -189,6 +189,34 @@ from transformers import AutoModelForSeq2SeqLM
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||
```
|
||||
|
||||
### Fine-tuning using Seq2SeqTrainer
|
||||
To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Except the `Trainer` releated `TrainingArguments`, it shares the same argument names as that of `finetune.py` file. One notable difference is that, calculating generative metrics (BLEU, ROUGE) is optional and is controlled using the `--predict_with_generate` argument, set this argument to calculate BLEU and ROUGE metrics.
|
||||
|
||||
To see all the possible command line options, run:
|
||||
|
||||
```bash
|
||||
./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help
|
||||
```
|
||||
|
||||
**At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.**
|
||||
|
||||
All `Seq2SeqTrainer` based fine-tuning scripts are included in the `builtin_trainer` directory.
|
||||
|
||||
#### TPU Training
|
||||
`Seq2SeqTrainer` supports TPU training with few caveats
|
||||
1. As `generate` method does not work on TPU at the moment, `predict_with_generate` can not be used. You should use `--prediction_loss_only` to only calculate loss, and do not set `--do_predict` and `--predict_with_generate`.
|
||||
2. All sequences should be padded to be of equal length otherwise it leads to extremely slow training. (`finetune_trainer.py` does this automatically when running on TPU.)
|
||||
|
||||
We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a --num_cores flag to this script, then your regular training script with its arguments (this is similar to the torch.distributed.launch helper for torch.distributed).
|
||||
|
||||
`builtin_trainer/finetune_tpu.sh` script provides minimal arguments needed for TPU training.
|
||||
|
||||
Following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 and should complete one epoch in ~5-6 mins.
|
||||
|
||||
```bash
|
||||
./builtin_trainer/train_distil_marian_enro_tpu.sh
|
||||
```
|
||||
|
||||
### Evaluation Commands
|
||||
|
||||
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
||||
|
||||
9
examples/seq2seq/builtin_trainer/finetune.sh
Normal file
9
examples/seq2seq/builtin_trainer/finetune.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./builtin_trainer/finetune.sh --help to see all the possible options
|
||||
python finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||
--predict_with_generate \
|
||||
--n_val 1000 \
|
||||
"$@"
|
||||
11
examples/seq2seq/builtin_trainer/finetune_tpu.sh
Normal file
11
examples/seq2seq/builtin_trainer/finetune_tpu.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
export TPU_NUM_CORES=8
|
||||
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./builtin_trainer/finetune_tpu.sh --help to see all the possible options
|
||||
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--do_train --do_eval --evaluate_during_training \
|
||||
--prediction_loss_only \
|
||||
--n_val 1000 \
|
||||
"$@"
|
||||
23
examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh
Normal file
23
examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh
Normal file
@@ -0,0 +1,23 @@
|
||||
export WANDB_PROJECT=distil-marian
|
||||
export BS=64
|
||||
export GAS=1
|
||||
export m=sshleifer/student_marian_en_ro_6_3
|
||||
export MAX_LEN=128
|
||||
python finetune_trainer.py \
|
||||
--tokenizer_name $m --model_name_or_path $m \
|
||||
--data_dir $ENRO_DIR \
|
||||
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
|
||||
--learning_rate=3e-4 \
|
||||
--warmup_steps 500 --sortish_sampler \
|
||||
--fp16 \
|
||||
--gradient_accumulation_steps=$GAS \
|
||||
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--do_train --do_eval --do_predict --evaluate_during_training\
|
||||
--predict_with_generate --logging_first_step \
|
||||
--task translation --label_smoothing 0.1 \
|
||||
--run_name marian_en_ro_6_3 \
|
||||
"$@"
|
||||
@@ -0,0 +1,24 @@
|
||||
export WANDB_PROJECT=distil-marian
|
||||
export BS=64
|
||||
export m=sshleifer/student_marian_en_ro_6_3
|
||||
export MAX_LEN=128
|
||||
export TPU_NUM_CORES=8
|
||||
|
||||
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
finetune_trainer.py \
|
||||
--tokenizer_name $m --model_name_or_path $m \
|
||||
--data_dir $ENRO_DIR \
|
||||
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
|
||||
--learning_rate=3e-4 \
|
||||
--warmup_steps 500 \
|
||||
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 500 --eval_steps 500 \
|
||||
--logging_first_step --logging_steps 200 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--do_train --do_eval --evaluate_during_training \
|
||||
--prediction_loss_only \
|
||||
--task translation --label_smoothing 0.1 \
|
||||
--run_name marian_en_ro_6_3 \
|
||||
"$@"
|
||||
26
examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh
Normal file
26
examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh
Normal file
@@ -0,0 +1,26 @@
|
||||
export WANDB_PROJECT=distilbart-cnn
|
||||
export BS=32
|
||||
export GAS=1
|
||||
export m=sshleifer/student_cnn_12_6
|
||||
export tok=facebook/bart-large
|
||||
export MAX_TGT_LEN=142
|
||||
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path $m --tokenizer_name $tok \
|
||||
--data_dir $CNN_DIR \
|
||||
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 --sortish_sampler \
|
||||
--fp16 \
|
||||
--n_val 500 \
|
||||
--gradient_accumulation_steps=$GAS \
|
||||
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=2 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--logging_first_step \
|
||||
--max_target_length $MAX_TGT_LEN --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||
--predict_with_generate \
|
||||
--run_name distilbart-cnn-12-6 \
|
||||
"$@"
|
||||
22
examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh
Normal file
22
examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh
Normal file
@@ -0,0 +1,22 @@
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||
--data_dir $ENRO_DIR \
|
||||
--output_dir mbart_cc25_enro --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 \
|
||||
--fp16 \
|
||||
--label_smoothing 0.1 \
|
||||
--adam_eps 1e-06 \
|
||||
--src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_embeds \
|
||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||
--max_source_length 128 --max_target_length 128 \
|
||||
--val_max_target_length 128 --test_max_target_length 128 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs 6 \
|
||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||
--do_train --do_eval --do_predict --evaluate_during_training \
|
||||
--predict_with_generate --logging_first_step
|
||||
--task translation \
|
||||
--run_name mbart_en_ro \
|
||||
"$@"
|
||||
442
examples/seq2seq/finetune_trainer.py
Normal file
442
examples/seq2seq/finetune_trainer.py
Normal file
@@ -0,0 +1,442 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
T5Tokenizer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
freeze_params,
|
||||
lmap,
|
||||
trim_batch,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["labels"],
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||
labels = torch.stack([x["labels"] for x in batch])
|
||||
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
labels = labels
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
labels = labels
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
decoder_start_token_id = self.pad_token_id
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = decoder_start_token_id
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.data_args.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.data_args.tgt_lang,
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Parameters:
|
||||
label_smoothing (:obj:`float`, `optional`, defaults to 0):
|
||||
The label smoothing epsilon to apply (if not zero).
|
||||
sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to SortishSamler or not. It sorts the inputs according to lenghts in-order to minimizing the padding size.
|
||||
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
||||
"""
|
||||
|
||||
label_smoothing: Optional[float] = field(
|
||||
default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
|
||||
)
|
||||
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."})
|
||||
predict_with_generate: bool = field(
|
||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."})
|
||||
freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
data_dir: str = field(
|
||||
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
||||
)
|
||||
task: Optional[str] = field(
|
||||
default="summarization",
|
||||
metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=142,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
test_max_target_length: Optional[int] = field(
|
||||
default=142,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
||||
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
|
||||
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
|
||||
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=".ckpt" in model_args.model_name_or_path,
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# use task specific params
|
||||
use_task_specific_params(model, data_args.task)
|
||||
|
||||
# set num_beams for evaluation
|
||||
if data_args.eval_beams is not None:
|
||||
model.config.num_beams = data_args.eval_beams
|
||||
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
|
||||
|
||||
# set max length for generation
|
||||
model.config.max_generate_length = data_args.val_max_target_length
|
||||
|
||||
# set decoder_start_token_id for MBart
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
|
||||
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
model.config.decoder_start_token_id = decoder_start_token_id
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
pred_str = lmap(str.strip, pred_str)
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
def freeze_embeds(model: torch.nn.Module):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
try:
|
||||
freeze_params(model.model.shared)
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
except AttributeError:
|
||||
freeze_params(model.shared)
|
||||
for d in [model.encoder, model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
if model_args.freeze_encoder:
|
||||
freeze_params(model.get_encoder())
|
||||
assert_all_frozen(model.get_encoder())
|
||||
|
||||
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
dataset_class(
|
||||
tokenizer,
|
||||
type_path="train",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_train,
|
||||
max_target_length=data_args.max_target_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
dataset_class(
|
||||
tokenizer,
|
||||
type_path="val",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_val,
|
||||
max_target_length=data_args.val_max_target_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
test_dataset = (
|
||||
dataset_class(
|
||||
tokenizer,
|
||||
type_path="test",
|
||||
data_dir=data_args.data_dir,
|
||||
n_obs=data_args.n_test,
|
||||
max_target_length=data_args.test_max_target_length,
|
||||
max_source_length=data_args.max_source_length,
|
||||
prefix=model.config.prefix or "",
|
||||
)
|
||||
if training_args.do_predict
|
||||
else None
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model()
|
||||
# For convenience, we also re-save the tokenizer to the same directory,
|
||||
# so that you can share your model easily on huggingface.co/models =)
|
||||
if trainer.is_world_process_zero():
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
eval_results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
result = trainer.evaluate()
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_eval_file, "w") as f:
|
||||
json.dump(result, f)
|
||||
|
||||
eval_results.update(result)
|
||||
|
||||
if training_args.do_predict:
|
||||
logging.info("*** Test ***")
|
||||
|
||||
test_output = trainer.predict(test_dataset=test_dataset)
|
||||
test_metrics = test_output.metrics
|
||||
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
|
||||
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
logger.info("***** Test results *****")
|
||||
for key, value in test_metrics.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
|
||||
with open(output_test_file, "w") as f:
|
||||
json.dump(test_metrics, f)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
|
||||
test_preds = lmap(str.strip, test_preds)
|
||||
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_pred_file, "w") as f:
|
||||
f.write("\n".join(test_preds))
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
126
examples/seq2seq/seq2seq_trainer.py
Normal file
126
examples/seq2seq/seq2seq_trainer.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.trainer import get_tpu_sampler
|
||||
|
||||
|
||||
try:
|
||||
from .utils import label_smoothed_nll_loss
|
||||
except ImportError:
|
||||
from utils import label_smoothed_nll_loss
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
return None
|
||||
elif is_torch_tpu_available():
|
||||
return get_tpu_sampler(self.train_dataset)
|
||||
else:
|
||||
if self.args.sortish_sampler:
|
||||
self.train_dataset.make_sortish_sampler(
|
||||
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
|
||||
)
|
||||
|
||||
return (
|
||||
RandomSampler(self.train_dataset)
|
||||
if self.args.local_rank == -1
|
||||
else DistributedSampler(self.train_dataset)
|
||||
)
|
||||
|
||||
def compute_loss(self, model, inputs):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
logits = outputs[0]
|
||||
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
|
||||
|
||||
def _compute_loss(self, logits, labels, ignore_index):
|
||||
if self.args.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||
assert logits.shape[-1] == self.model.config.vocab_size
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
|
||||
)
|
||||
return loss
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (:obj:`bool`):
|
||||
Whether or not to return the loss only.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
A tuple with the loss, logits and labels (each being optional).
|
||||
"""
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
max_length = (
|
||||
model.config.max_generate_length
|
||||
if hasattr(model.config, "max_generate_length")
|
||||
else model.config.max_position_embeddings
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||
generated_tokens = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
use_cache=True,
|
||||
num_beams=model.config.num_beams,
|
||||
max_length=max_length,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
generated_tokens = self._pad_tensors_to_max_len(
|
||||
generated_tokens, max_length, model.config.pad_token_id
|
||||
)
|
||||
|
||||
labels_out = inputs.get("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[1]
|
||||
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
|
||||
loss = loss.mean().item()
|
||||
if self.args.prediction_loss_only:
|
||||
logits = None
|
||||
else:
|
||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
labels_out = labels_out.detach()
|
||||
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||
padded_tensor = pad_token_id * torch.ones(
|
||||
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||
return padded_tensor
|
||||
96
examples/seq2seq/test_finetune_trainer.py
Normal file
96
examples/seq2seq/test_finetune_trainer.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
MODEL_NAME = MBART_TINY
|
||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
@slow
|
||||
def test_model_download():
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer():
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = tempfile.mkdtemp(prefix="marian_output")
|
||||
max_len = "128"
|
||||
num_train_epochs = 4
|
||||
eval_steps = 2
|
||||
argv = [
|
||||
"--model_name_or_path",
|
||||
MARIAN_MODEL,
|
||||
"--data_dir",
|
||||
data_dir,
|
||||
"--output_dir",
|
||||
output_dir,
|
||||
"--overwrite_output_dir",
|
||||
"--n_train",
|
||||
"8",
|
||||
"--n_val",
|
||||
"8",
|
||||
"--max_source_length",
|
||||
max_len,
|
||||
"--max_target_length",
|
||||
max_len,
|
||||
"--val_max_target_length",
|
||||
max_len,
|
||||
"--do_train",
|
||||
"--do_eval",
|
||||
"--do_predict",
|
||||
"--num_train_epochs",
|
||||
str(num_train_epochs),
|
||||
"--per_device_train_batch_size",
|
||||
"4",
|
||||
"--per_device_eval_batch_size",
|
||||
"4",
|
||||
"--learning_rate",
|
||||
"3e-4",
|
||||
"--warmup_steps",
|
||||
"8",
|
||||
"--evaluate_during_training",
|
||||
"--predict_with_generate",
|
||||
"--logging_steps",
|
||||
0,
|
||||
"--save_steps",
|
||||
str(eval_steps),
|
||||
"--eval_steps",
|
||||
str(eval_steps),
|
||||
"--sortish_sampler",
|
||||
"--label_smoothing",
|
||||
"0.1",
|
||||
"--task",
|
||||
"translation",
|
||||
]
|
||||
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
# Check metrics
|
||||
logs = load_json(os.path.join(output_dir, "log_history.json"))
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
last_step_stats = eval_metrics[-1]
|
||||
|
||||
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # model learned nothing
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
|
||||
# test if do_predict saves generations and metrics
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
72
examples/seq2seq/xla_spawn.py
Normal file
72
examples/seq2seq/xla_spawn.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
A simple launcher script for TPU training
|
||||
|
||||
Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
|
||||
|
||||
::
|
||||
>>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
|
||||
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
|
||||
arguments of your training script)
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from argparse import REMAINDER, ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
Helper function parsing the command line options
|
||||
@retval ArgumentParser
|
||||
"""
|
||||
parser = ArgumentParser(
|
||||
description=(
|
||||
"PyTorch TPU distributed training launch "
|
||||
"helper utility that will spawn up "
|
||||
"multiple distributed processes"
|
||||
)
|
||||
)
|
||||
|
||||
# Optional arguments for the launch helper
|
||||
parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
|
||||
|
||||
# positional
|
||||
parser.add_argument(
|
||||
"training_script",
|
||||
type=str,
|
||||
help=(
|
||||
"The full path to the single TPU training "
|
||||
"program/script to be launched in parallel, "
|
||||
"followed by all the arguments for the "
|
||||
"training script"
|
||||
),
|
||||
)
|
||||
|
||||
# rest from the training program
|
||||
parser.add_argument("training_script_args", nargs=REMAINDER)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Import training_script as a module.
|
||||
script_fpath = Path(args.training_script)
|
||||
sys.path.append(str(script_fpath.parent.resolve()))
|
||||
mod_name = script_fpath.stem
|
||||
mod = importlib.import_module(mod_name)
|
||||
|
||||
# Patch sys.argv
|
||||
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
|
||||
|
||||
xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user