From 72d363d9791c64eff21ed7535c690de2c636d508 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 1 Oct 2020 21:49:29 +0530 Subject: [PATCH] [examples/s2s] clean up finetune_trainer (#7509) --- examples/seq2seq/finetune_trainer.py | 108 ++------------------------- examples/seq2seq/seq2seq_trainer.py | 8 +- examples/seq2seq/utils.py | 96 +++++++++++++++++++++++- 3 files changed, 107 insertions(+), 105 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index baa0d5b70e..c1192c971f 100644 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -2,37 +2,29 @@ import logging import os import sys from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple - -import numpy as np -import torch +from typing import Optional from seq2seq_trainer import Seq2SeqTrainer from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, - BartTokenizer, - EvalPrediction, HfArgumentParser, MBartTokenizer, - T5Tokenizer, TrainingArguments, set_seed, ) -from transformers.modeling_bart import shift_tokens_right from transformers.trainer_utils import EvaluationStrategy from utils import ( LegacySeq2SeqDataset, + Seq2SeqDataCollator, Seq2SeqDataset, assert_all_frozen, - calculate_bleu, - calculate_rouge, + build_compute_metrics_fn, freeze_embeds, freeze_params, lmap, save_json, - trim_batch, use_task_specific_params, write_txt_file, ) @@ -41,66 +33,6 @@ from utils import ( logger = logging.getLogger(__name__) -class Seq2SeqDataCollator: - def __init__(self, tokenizer, data_args, tpu_num_cores=None): - self.tokenizer = tokenizer - self.pad_token_id = tokenizer.pad_token_id - assert self.pad_token_id is not None, "self.pad_token_id must be defined" - self.data_args = data_args - self.tpu_num_cores = tpu_num_cores - self.add_prefix_space = isinstance(tokenizer, BartTokenizer) - - def __call__(self, batch) -> Dict[str, torch.Tensor]: - if hasattr(self.tokenizer, "prepare_seq2seq_batch"): - batch = self._encode(batch) - input_ids, attention_mask, labels = ( - batch["input_ids"], - batch["attention_mask"], - batch["labels"], - ) - else: - input_ids = torch.stack([x["input_ids"] for x in batch]) - attention_mask = torch.stack([x["attention_mask"] for x in batch]) - labels = torch.stack([x["labels"] for x in batch]) - - labels = trim_batch(labels, self.pad_token_id) - input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) - - if isinstance(self.tokenizer, T5Tokenizer): - decoder_input_ids = self._shift_right_t5(labels) - else: - decoder_input_ids = shift_tokens_right(labels, self.pad_token_id) - - batch = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "labels": labels, - } - return batch - - def _shift_right_t5(self, input_ids): - # shift inputs to the right - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = self.pad_token_id - return shifted_input_ids - - def _encode(self, batch) -> Dict[str, torch.Tensor]: - batch_encoding = self.tokenizer.prepare_seq2seq_batch( - [x["src_texts"] for x in batch], - src_lang=self.data_args.src_lang, - tgt_texts=[x["tgt_texts"] for x in batch], - tgt_lang=self.data_args.tgt_lang, - max_length=self.data_args.max_source_length, - max_target_length=self.data_args.max_target_length, - padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack - return_tensors="pt", - add_prefix_space=self.add_prefix_space, - ) - return batch_encoding.data - - @dataclass class Seq2SeqTrainingArguments(TrainingArguments): """ @@ -271,34 +203,6 @@ def main(): ), "mBart requires --tgt_lang and --src_lang" model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] - def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: - def non_pad_len(tokens: np.ndarray) -> int: - return np.count_nonzero(tokens != tokenizer.pad_token_id) - - def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: - pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) - label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) - pred_str = lmap(str.strip, pred_str) - label_str = lmap(str.strip, label_str) - return pred_str, label_str - - def summarization_metrics(pred: EvalPrediction) -> Dict: - pred_str, label_str = decode_pred(pred) - rouge: Dict = calculate_rouge(pred_str, label_str) - summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) - rouge.update({"gen_len": summ_len}) - return rouge - - def translation_metrics(pred: EvalPrediction) -> Dict: - pred_str, label_str = decode_pred(pred) - bleu: Dict = calculate_bleu(pred_str, label_str) - gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) - bleu.update({"gen_len": gen_len}) - return bleu - - compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics - return compute_metrics_fn - if model_args.freeze_embeds: freeze_embeds(model) if model_args.freeze_encoder: @@ -349,13 +253,17 @@ def main(): ) # Initialize our Trainer + compute_metrics_fn = ( + build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None + ) trainer = Seq2SeqTrainer( model=model, + config=config, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), - compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None, + compute_metrics=compute_metrics_fn, data_args=data_args, ) diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 885e7263f8..1f8d7fb415 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -20,11 +20,13 @@ logger = logging.getLogger(__name__) class Seq2SeqTrainer(Trainer): - def __init__(self, data_args, *args, **kwargs): + def __init__(self, config, data_args, *args, **kwargs): super().__init__(*args, **kwargs) + self.config = config self.data_args = data_args self.max_gen_length = data_args.val_max_target_length - self.pad_token_id = self.model.config.pad_token_id + self.pad_token_id = self.config.pad_token_id + self.vocab_size = self.config.vocab_size def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset): @@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer): if self.args.label_smoothing == 0: # Same behavior as modeling_bart.py loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) - assert logits.shape[-1] == self.model.config.vocab_size + assert logits.shape[-1] == self.vocab_size loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) else: lprobs = torch.nn.functional.log_softmax(logits, dim=-1) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index f64f0104cb..18928e3dd3 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -7,7 +7,7 @@ import pickle import socket from logging import getLogger from pathlib import Path -from typing import Callable, Dict, Iterable, List, Union +from typing import Callable, Dict, Iterable, List, Tuple, Union import git import numpy as np @@ -19,8 +19,9 @@ from torch import nn from torch.utils.data import Dataset, Sampler from sentence_splitter import add_newline_to_end_of_each_sentence -from transformers import BartTokenizer +from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer from transformers.file_utils import cached_property +from transformers.modeling_bart import shift_tokens_right try: @@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} +def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: + def non_pad_len(tokens: np.ndarray) -> int: + return np.count_nonzero(tokens != tokenizer.pad_token_id) + + def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: + pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + pred_str = lmap(str.strip, pred_str) + label_str = lmap(str.strip, label_str) + return pred_str, label_str + + def summarization_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + rouge: Dict = calculate_rouge(pred_str, label_str) + summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) + rouge.update({"gen_len": summ_len}) + return rouge + + def translation_metrics(pred: EvalPrediction) -> Dict: + pred_str, label_str = decode_pred(pred) + bleu: Dict = calculate_bleu(pred_str, label_str) + gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) + bleu.update({"gen_len": gen_len}) + return bleu + + compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics + return compute_metrics_fn + + def trim_batch( input_ids, pad_token_id, @@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset): return batch_encoding +class Seq2SeqDataCollator: + def __init__(self, tokenizer, data_args, tpu_num_cores=None): + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + assert ( + self.pad_token_id is not None + ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." + self.data_args = data_args + self.tpu_num_cores = tpu_num_cores + self.add_prefix_space = isinstance(tokenizer, BartTokenizer) + + def __call__(self, batch) -> Dict[str, torch.Tensor]: + if hasattr(self.tokenizer, "prepare_seq2seq_batch"): + batch = self._encode(batch) + input_ids, attention_mask, labels = ( + batch["input_ids"], + batch["attention_mask"], + batch["labels"], + ) + else: + input_ids = torch.stack([x["input_ids"] for x in batch]) + attention_mask = torch.stack([x["attention_mask"] for x in batch]) + labels = torch.stack([x["labels"] for x in batch]) + + labels = trim_batch(labels, self.pad_token_id) + input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) + + if isinstance(self.tokenizer, T5Tokenizer): + decoder_input_ids = self._shift_right_t5(labels) + else: + decoder_input_ids = shift_tokens_right(labels, self.pad_token_id) + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "labels": labels, + } + return batch + + def _shift_right_t5(self, input_ids): + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.pad_token_id + return shifted_input_ids + + def _encode(self, batch) -> Dict[str, torch.Tensor]: + batch_encoding = self.tokenizer.prepare_seq2seq_batch( + [x["src_texts"] for x in batch], + src_lang=self.data_args.src_lang, + tgt_texts=[x["tgt_texts"] for x in batch], + tgt_lang=self.data_args.tgt_lang, + max_length=self.data_args.max_source_length, + max_target_length=self.data_args.max_target_length, + padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack + return_tensors="pt", + add_prefix_space=self.add_prefix_space, + ) + return batch_encoding.data + + class SortishSampler(Sampler): "Go through the text data by order of src length with a bit of randomness. From fastai repo."