From 9e68d075a4100906509170498480823e7e61874a Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 25 Sep 2020 04:16:58 +0530 Subject: [PATCH] Seq2SeqTrainer (#6769) Co-authored-by: Sam Shleifer --- examples/seq2seq/README.md | 28 ++ examples/seq2seq/builtin_trainer/finetune.sh | 9 + .../seq2seq/builtin_trainer/finetune_tpu.sh | 11 + .../train_distil_marian_enro.sh | 23 + .../train_distil_marian_enro_tpu.sh | 24 + .../builtin_trainer/train_distilbart_cnn.sh | 26 ++ .../builtin_trainer/train_mbart_cc25_enro.sh | 22 + examples/seq2seq/finetune_trainer.py | 442 ++++++++++++++++++ examples/seq2seq/seq2seq_trainer.py | 126 +++++ examples/seq2seq/test_finetune_trainer.py | 96 ++++ examples/seq2seq/xla_spawn.py | 72 +++ 11 files changed, 879 insertions(+) create mode 100644 examples/seq2seq/builtin_trainer/finetune.sh create mode 100644 examples/seq2seq/builtin_trainer/finetune_tpu.sh create mode 100644 examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh create mode 100644 examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh create mode 100644 examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh create mode 100644 examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh create mode 100644 examples/seq2seq/finetune_trainer.py create mode 100644 examples/seq2seq/seq2seq_trainer.py create mode 100644 examples/seq2seq/test_finetune_trainer.py create mode 100644 examples/seq2seq/xla_spawn.py diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index b08794fb76..93afa25b7b 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -189,6 +189,34 @@ from transformers import AutoModelForSeq2SeqLM model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr') ``` +### Fine-tuning using Seq2SeqTrainer +To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. 
Except for the `Trainer`-related `TrainingArguments`, it shares the same argument names as the `finetune.py` script. One notable difference is that calculating generative metrics (BLEU, ROUGE) is optional and is controlled by the `--predict_with_generate` argument; set this argument to calculate BLEU and ROUGE metrics. + +To see all the possible command line options, run: + +```bash +./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help +``` + +**At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.** + +All `Seq2SeqTrainer` based fine-tuning scripts are included in the `builtin_trainer` directory. + +#### TPU Training +`Seq2SeqTrainer` supports TPU training with a few caveats: +1. As the `generate` method does not work on TPU at the moment, `predict_with_generate` cannot be used. You should use `--prediction_loss_only` to only calculate loss, and do not set `--do_predict` and `--predict_with_generate`. +2. All sequences should be padded to be of equal length, otherwise training becomes extremely slow. (`finetune_trainer.py` does this automatically when running on TPU.) + +We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a `--num_cores` flag to this script, then your regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for `torch.distributed`). + +The `builtin_trainer/finetune_tpu.sh` script provides the minimal arguments needed for TPU training. + +The following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-8 and should complete one epoch in ~5-6 mins. + +```bash +./builtin_trainer/train_distil_marian_enro_tpu.sh +``` + ### Evaluation Commands To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models. 
diff --git a/examples/seq2seq/builtin_trainer/finetune.sh b/examples/seq2seq/builtin_trainer/finetune.sh new file mode 100644 index 0000000000..65f207c21a --- /dev/null +++ b/examples/seq2seq/builtin_trainer/finetune.sh @@ -0,0 +1,9 @@ +# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path +# run ./builtin_trainer/finetune.sh --help to see all the possible options +python finetune_trainer.py \ + --learning_rate=3e-5 \ + --fp16 \ + --do_train --do_eval --do_predict --evaluate_during_training \ + --predict_with_generate \ + --n_val 1000 \ + "$@" diff --git a/examples/seq2seq/builtin_trainer/finetune_tpu.sh b/examples/seq2seq/builtin_trainer/finetune_tpu.sh new file mode 100644 index 0000000000..8bd367c852 --- /dev/null +++ b/examples/seq2seq/builtin_trainer/finetune_tpu.sh @@ -0,0 +1,11 @@ +export TPU_NUM_CORES=8 + +# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path +# run ./builtin_trainer/finetune_tpu.sh --help to see all the possible options +python xla_spawn.py --num_cores $TPU_NUM_CORES \ + finetune_trainer.py \ + --learning_rate=3e-5 \ + --do_train --do_eval --evaluate_during_training \ + --prediction_loss_only \ + --n_val 1000 \ + "$@" diff --git a/examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh b/examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh new file mode 100644 index 0000000000..86274c2e43 --- /dev/null +++ b/examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh @@ -0,0 +1,23 @@ +export WANDB_PROJECT=distil-marian +export BS=64 +export GAS=1 +export m=sshleifer/student_marian_en_ro_6_3 +export MAX_LEN=128 +python finetune_trainer.py \ + --tokenizer_name $m --model_name_or_path $m \ + --data_dir $ENRO_DIR \ + --output_dir marian_en_ro_6_3 --overwrite_output_dir \ + --learning_rate=3e-4 \ + --warmup_steps 500 --sortish_sampler \ + --fp16 \ + --gradient_accumulation_steps=$GAS \ + 
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \ + --freeze_encoder --freeze_embeds \ + --num_train_epochs=6 \ + --save_steps 3000 --eval_steps 3000 \ + --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --do_train --do_eval --do_predict --evaluate_during_training\ + --predict_with_generate --logging_first_step \ + --task translation --label_smoothing 0.1 \ + --run_name marian_en_ro_6_3 \ + "$@" diff --git a/examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh b/examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh new file mode 100644 index 0000000000..93256f902d --- /dev/null +++ b/examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh @@ -0,0 +1,24 @@ +export WANDB_PROJECT=distil-marian +export BS=64 +export m=sshleifer/student_marian_en_ro_6_3 +export MAX_LEN=128 +export TPU_NUM_CORES=8 + +python xla_spawn.py --num_cores $TPU_NUM_CORES \ + finetune_trainer.py \ + --tokenizer_name $m --model_name_or_path $m \ + --data_dir $ENRO_DIR \ + --output_dir marian_en_ro_6_3 --overwrite_output_dir \ + --learning_rate=3e-4 \ + --warmup_steps 500 \ + --per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \ + --freeze_encoder --freeze_embeds \ + --num_train_epochs=6 \ + --save_steps 500 --eval_steps 500 \ + --logging_first_step --logging_steps 200 \ + --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --do_train --do_eval --evaluate_during_training \ + --prediction_loss_only \ + --task translation --label_smoothing 0.1 \ + --run_name marian_en_ro_6_3 \ + "$@" diff --git a/examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh b/examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh new file mode 100644 index 0000000000..6317481377 --- /dev/null +++ b/examples/seq2seq/builtin_trainer/train_distilbart_cnn.sh @@ -0,0 +1,26 @@ +export 
WANDB_PROJECT=distilbart-cnn +export BS=32 +export GAS=1 +export m=sshleifer/student_cnn_12_6 +export tok=facebook/bart-large +export MAX_TGT_LEN=142 + +python finetune_trainer.py \ + --model_name_or_path $m --tokenizer_name $tok \ + --data_dir $CNN_DIR \ + --output_dir distilbart-cnn-12-6 --overwrite_output_dir \ + --learning_rate=3e-5 \ + --warmup_steps 500 --sortish_sampler \ + --fp16 \ + --n_val 500 \ + --gradient_accumulation_steps=$GAS \ + --per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \ + --freeze_encoder --freeze_embeds \ + --num_train_epochs=2 \ + --save_steps 3000 --eval_steps 3000 \ + --logging_first_step \ + --max_target_length $MAX_TGT_LEN --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \ + --do_train --do_eval --do_predict --evaluate_during_training \ + --predict_with_generate \ + --run_name distilbart-cnn-12-6 \ + "$@" diff --git a/examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh b/examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh new file mode 100644 index 0000000000..e113004807 --- /dev/null +++ b/examples/seq2seq/builtin_trainer/train_mbart_cc25_enro.sh @@ -0,0 +1,22 @@ +python finetune_trainer.py \ + --model_name_or_path=facebook/mbart-large-cc25 \ + --data_dir $ENRO_DIR \ + --output_dir mbart_cc25_enro --overwrite_output_dir \ + --learning_rate=3e-5 \ + --warmup_steps 500 \ + --fp16 \ + --label_smoothing 0.1 \ + --adam_eps 1e-06 \ + --src_lang en_XX --tgt_lang ro_RO \ + --freeze_embeds \ + --per_device_train_batch_size=4 --per_device_eval_batch_size=4 \ + --max_source_length 128 --max_target_length 128 \ + --val_max_target_length 128 --test_max_target_length 128 \ + --sortish_sampler \ + --num_train_epochs 6 \ + --save_steps 25000 --eval_steps 25000 --logging_steps 1000 \ + --do_train --do_eval --do_predict --evaluate_during_training \ + --predict_with_generate --logging_first_step \ + --task translation \ + --run_name mbart_en_ro \ + "$@" diff --git 
a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py new file mode 100644 index 0000000000..9dca6f3974 --- /dev/null +++ b/examples/seq2seq/finetune_trainer.py @@ -0,0 +1,442 @@ +import json +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch + +from seq2seq_trainer import Seq2SeqTrainer +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + BartTokenizer, + EvalPrediction, + HfArgumentParser, + MBartTokenizer, + T5Tokenizer, + TrainingArguments, + set_seed, +) +from transformers.modeling_bart import shift_tokens_right +from utils import ( + LegacySeq2SeqDataset, + Seq2SeqDataset, + assert_all_frozen, + calculate_bleu, + calculate_rouge, + freeze_params, + lmap, + trim_batch, + use_task_specific_params, +) + + +logger = logging.getLogger(__name__) + + +class Seq2SeqDataCollator: + def __init__(self, tokenizer, data_args, tpu_num_cores=None): + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + self.data_args = data_args + self.tpu_num_cores = tpu_num_cores + self.add_prefix_space = isinstance(tokenizer, BartTokenizer) + + def __call__(self, batch) -> Dict[str, torch.Tensor]: + if hasattr(self.tokenizer, "prepare_seq2seq_batch"): + batch = self._encode(batch) + input_ids, attention_mask, labels = ( + batch["input_ids"], + batch["attention_mask"], + batch["labels"], + ) + else: + input_ids = torch.stack([x["input_ids"] for x in batch]) + attention_mask = torch.stack([x["attention_mask"] for x in batch]) + labels = torch.stack([x["labels"] for x in batch]) + + labels = trim_batch(labels, self.pad_token_id) + input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) + + if isinstance(self.tokenizer, T5Tokenizer): + decoder_input_ids = self._shift_right_t5(labels) + labels = labels + else: + decoder_input_ids = 
shift_tokens_right(labels, self.pad_token_id) + labels = labels + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "labels": labels, + } + return batch + + def _shift_right_t5(self, input_ids): + decoder_start_token_id = self.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + return shifted_input_ids + + def _encode(self, batch) -> Dict[str, torch.Tensor]: + batch_encoding = self.tokenizer.prepare_seq2seq_batch( + [x["src_texts"] for x in batch], + src_lang=self.data_args.src_lang, + tgt_texts=[x["tgt_texts"] for x in batch], + tgt_lang=self.data_args.tgt_lang, + max_length=self.data_args.max_source_length, + max_target_length=self.data_args.max_target_length, + padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack + return_tensors="pt", + add_prefix_space=self.add_prefix_space, + ) + return batch_encoding.data + + +@dataclass +class Seq2SeqTrainingArguments(TrainingArguments): + """ + Parameters: + label_smoothing (:obj:`float`, `optional`, defaults to 0): + The label smoothing epsilon to apply (if not zero). + sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use SortishSampler or not. It sorts the inputs according to lengths in order to minimize the padding size. + predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). 
+ """ + + label_smoothing: Optional[float] = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."} + ) + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."}) + freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."}) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + data_dir: str = field( + metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} + ) + task: Optional[str] = field( + default="summarization", + metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=142, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + test_max_target_length: Optional[int] = field( + default=142, + metadata={ + "help": "The maximum total sequence length for test target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."}) + n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."}) + n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."}) + src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."}) + tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."}) + eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."}) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. 
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=".ckpt" in model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + ) + + # use task specific params + use_task_specific_params(model, data_args.task) + + # set num_beams for evaluation + if data_args.eval_beams is not None: + model.config.num_beams = data_args.eval_beams + assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1" + + # set max length for generation + model.config.max_generate_length = data_args.val_max_target_length + + # set decoder_start_token_id for MBart + if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): + decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] + model.config.decoder_start_token_id = decoder_start_token_id + + def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: + def non_pad_len(tokens: np.ndarray) -> int: + return np.count_nonzero(tokens != tokenizer.pad_token_id) + + def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: + pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + pred_str = lmap(str.strip, pred_str) + label_str = lmap(str.strip, label_str) + return pred_str, label_str + + def summarization_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + rouge: Dict = calculate_rouge(pred_str, label_str) + summ_len = np.mean(lmap(non_pad_len, pred.predictions)) + rouge.update({"gen_len": summ_len}) 
+ return rouge + + def translation_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + bleu: Dict = calculate_bleu(pred_str, label_str) + gen_len = np.mean(lmap(non_pad_len, pred.predictions)) + bleu.update({"gen_len": gen_len}) + return bleu + + compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics + return compute_metrics_fn + + def freeze_embeds(model: torch.nn.Module): + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" + try: + freeze_params(model.model.shared) + for d in [model.model.encoder, model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + except AttributeError: + freeze_params(model.shared) + for d in [model.encoder, model.decoder]: + freeze_params(d.embed_tokens) + + if model_args.freeze_embeds: + freeze_embeds(model) + if model_args.freeze_encoder: + freeze_params(model.get_encoder()) + assert_all_frozen(model.get_encoder()) + + dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset + + # Get datasets + train_dataset = ( + dataset_class( + tokenizer, + type_path="train", + data_dir=data_args.data_dir, + n_obs=data_args.n_train, + max_target_length=data_args.max_target_length, + max_source_length=data_args.max_source_length, + prefix=model.config.prefix or "", + ) + if training_args.do_train + else None + ) + eval_dataset = ( + dataset_class( + tokenizer, + type_path="val", + data_dir=data_args.data_dir, + n_obs=data_args.n_val, + max_target_length=data_args.val_max_target_length, + max_source_length=data_args.max_source_length, + prefix=model.config.prefix or "", + ) + if training_args.do_eval + else None + ) + test_dataset = ( + dataset_class( + tokenizer, + type_path="test", + data_dir=data_args.data_dir, + n_obs=data_args.n_test, + max_target_length=data_args.test_max_target_length, + max_source_length=data_args.max_source_length, + 
prefix=model.config.prefix or "", + ) + if training_args.do_predict + else None + ) + + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), + compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None, + ) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + if trainer.is_world_process_zero(): + tokenizer.save_pretrained(training_args.output_dir) + + # Evaluation + eval_results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + result = trainer.evaluate() + + output_eval_file = os.path.join(training_args.output_dir, "eval_results.json") + if trainer.is_world_process_zero(): + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + + with open(output_eval_file, "w") as f: + json.dump(result, f) + + eval_results.update(result) + + if training_args.do_predict: + logger.info("*** Test ***") + + test_output = trainer.predict(test_dataset=test_dataset) + test_metrics = test_output.metrics + test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()} + + output_test_file = os.path.join(training_args.output_dir, "test_results.json") + + if trainer.is_world_process_zero(): + logger.info("***** Test results *****") + for key, value in test_metrics.items(): + logger.info(" %s = %s", key, value) + + with open(output_test_file, "w") as f: + json.dump(test_metrics, f) + + if training_args.predict_with_generate: + test_preds = tokenizer.batch_decode(test_output.predictions, 
skip_special_tokens=True) + test_preds = lmap(str.strip, test_preds) + output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt") + with open(output_test_pred_file, "w") as f: + f.write("\n".join(test_preds)) + + return eval_results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py new file mode 100644 index 0000000000..195eb4768e --- /dev/null +++ b/examples/seq2seq/seq2seq_trainer.py @@ -0,0 +1,126 @@ +import logging +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import DistributedSampler, RandomSampler + +from transformers import Trainer +from transformers.file_utils import is_torch_tpu_available +from transformers.trainer import get_tpu_sampler + + +try: + from .utils import label_smoothed_nll_loss +except ImportError: + from utils import label_smoothed_nll_loss + + +logger = logging.getLogger(__name__) + + +class Seq2SeqTrainer(Trainer): + def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: + if isinstance(self.train_dataset, torch.utils.data.IterableDataset): + return None + elif is_torch_tpu_available(): + return get_tpu_sampler(self.train_dataset) + else: + if self.args.sortish_sampler: + return self.train_dataset.make_sortish_sampler( + self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1 + ) + + return ( + RandomSampler(self.train_dataset) + if self.args.local_rank == -1 + else DistributedSampler(self.train_dataset) + ) + + def compute_loss(self, model, inputs): + labels = inputs.pop("labels") + outputs = model(**inputs, use_cache=False) + logits = outputs[0] + return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id) + + def _compute_loss(self, logits, labels, ignore_index): + if self.args.label_smoothing == 0: + # Same behavior as modeling_bart.py + loss_fct = 
torch.nn.CrossEntropyLoss(ignore_index=ignore_index) + assert logits.shape[-1] == self.model.config.vocab_size + loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) + else: + lprobs = torch.nn.functional.log_softmax(logits, dim=-1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index + ) + return loss + + def prediction_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on :obj:`model` using obj:`inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to evaluate. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (:obj:`bool`): + Whether or not to return the loss only. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + A tuple with the loss, logits and labels (each being optional). 
+ """ + inputs = self._prepare_inputs(inputs) + + max_length = ( + model.config.max_generate_length + if hasattr(model.config, "max_generate_length") + else model.config.max_position_embeddings + ) + + with torch.no_grad(): + if self.args.predict_with_generate and not self.args.prediction_loss_only: + generated_tokens = model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=True, + num_beams=model.config.num_beams, + max_length=max_length, + ) + # in case the batch is shorter than max length, the output should be padded + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, max_length, model.config.pad_token_id + ) + + labels_out = inputs.get("labels") + outputs = model(**inputs) + logits = outputs[1] + loss = self._compute_loss(logits, labels_out, model.config.pad_token_id) + loss = loss.mean().item() + if self.args.prediction_loss_only: + logits = None + else: + logits = generated_tokens if self.args.predict_with_generate else logits + + if self.args.prediction_loss_only: + return (loss, None, None) + + labels_out = labels_out.detach() + labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id) + return (loss, logits.detach(), labels) + + def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id): + padded_tensor = pad_token_id * torch.ones( + (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, : tensor.shape[-1]] = tensor + return padded_tensor diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py new file mode 100644 index 0000000000..2863a5aa5f --- /dev/null +++ b/examples/seq2seq/test_finetune_trainer.py @@ -0,0 +1,96 @@ +import os +import sys +import tempfile +from unittest.mock import patch + +from transformers import BartForConditionalGeneration, MarianMTModel +from transformers.testing_utils import slow + +from .finetune_trainer import main +from .test_seq2seq_examples import 
MBART_TINY +from .utils import load_json + + +MODEL_NAME = MBART_TINY +# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1" +MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1" + + +@slow +def test_model_download(): + """This warms up the cache so that we can time the next test without including download time, which varies between machines.""" + BartForConditionalGeneration.from_pretrained(MODEL_NAME) + MarianMTModel.from_pretrained(MARIAN_MODEL) + + +@slow +def test_finetune_trainer(): + data_dir = "examples/seq2seq/test_data/wmt_en_ro" + output_dir = tempfile.mkdtemp(prefix="marian_output") + max_len = "128" + num_train_epochs = 4 + eval_steps = 2 + argv = [ + "--model_name_or_path", + MARIAN_MODEL, + "--data_dir", + data_dir, + "--output_dir", + output_dir, + "--overwrite_output_dir", + "--n_train", + "8", + "--n_val", + "8", + "--max_source_length", + max_len, + "--max_target_length", + max_len, + "--val_max_target_length", + max_len, + "--do_train", + "--do_eval", + "--do_predict", + "--num_train_epochs", + str(num_train_epochs), + "--per_device_train_batch_size", + "4", + "--per_device_eval_batch_size", + "4", + "--learning_rate", + "3e-4", + "--warmup_steps", + "8", + "--evaluate_during_training", + "--predict_with_generate", + "--logging_steps", + "0", + "--save_steps", + str(eval_steps), + "--eval_steps", + str(eval_steps), + "--sortish_sampler", + "--label_smoothing", + "0.1", + "--task", + "translation", + ] + + testargs = ["finetune_trainer.py"] + argv + with patch.object(sys, "argv", testargs): + main() + + # Check metrics + logs = load_json(os.path.join(output_dir, "log_history.json")) + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] + first_step_stats = eval_metrics[0] + last_step_stats = eval_metrics[-1] + + assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"] # fails if the model learned nothing + assert isinstance(last_step_stats["eval_bleu"], float) + + # test if do_predict saves generations and metrics + contents = 
os.listdir(output_dir) + contents = {os.path.basename(p) for p in contents} + assert "test_generations.txt" in contents + assert "test_results.json" in contents diff --git a/examples/seq2seq/xla_spawn.py b/examples/seq2seq/xla_spawn.py new file mode 100644 index 0000000000..0889e57afc --- /dev/null +++ b/examples/seq2seq/xla_spawn.py @@ -0,0 +1,72 @@ +""" +A simple launcher script for TPU training + +Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py + +:: + >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +""" + + +import importlib +import sys +from argparse import REMAINDER, ArgumentParser +from pathlib import Path + +import torch_xla.distributed.xla_multiprocessing as xmp + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser( + description=( + "PyTorch TPU distributed training launch " + "helper utility that will spawn up " + "multiple distributed processes" + ) + ) + + # Optional arguments for the launch helper + parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).") + + # positional + parser.add_argument( + "training_script", + type=str, + help=( + "The full path to the single TPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script" + ), + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser.parse_args() + + +def main(): + args = parse_args() + + # Import training_script as a module. 
+ script_fpath = Path(args.training_script) + sys.path.append(str(script_fpath.parent.resolve())) + mod_name = script_fpath.stem + mod = importlib.import_module(mod_name) + + # Patch sys.argv + sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)] + + xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores) + + +if __name__ == "__main__": + main()