[examples/s2s] clean up finetune_trainer (#7509)

This commit is contained in:
Suraj Patil
2020-10-01 21:49:29 +05:30
committed by GitHub
parent bd2621583b
commit 72d363d979
3 changed files with 107 additions and 105 deletions

View File

@@ -2,37 +2,29 @@ import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from typing import Optional
from seq2seq_trainer import Seq2SeqTrainer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartTokenizer,
EvalPrediction,
HfArgumentParser,
MBartTokenizer,
T5Tokenizer,
TrainingArguments,
set_seed,
)
from transformers.modeling_bart import shift_tokens_right
from transformers.trainer_utils import EvaluationStrategy
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataCollator,
Seq2SeqDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
build_compute_metrics_fn,
freeze_embeds,
freeze_params,
lmap,
save_json,
trim_batch,
use_task_specific_params,
write_txt_file,
)
@@ -41,66 +33,6 @@ from utils import (
logger = logging.getLogger(__name__)
class Seq2SeqDataCollator:
    """Turn a list of dataset examples into one batch of seq2seq model inputs.

    Calling an instance returns a dict with the keys ``input_ids``,
    ``attention_mask``, ``decoder_input_ids`` and ``labels``.

    Two kinds of examples are handled:

    * If the tokenizer exposes ``prepare_seq2seq_batch``, every example must
      provide ``src_texts`` and ``tgt_texts``; tokenization happens here.
    * Otherwise every example must already provide ``input_ids``,
      ``attention_mask`` and ``labels`` tensors, which are stacked and then
      passed through ``trim_batch``.
    """

    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
        """Store the tokenizer and the data arguments used for collation.

        Args:
            tokenizer: tokenizer providing ``pad_token_id`` (must not be None).
            data_args: object read for ``src_lang``, ``tgt_lang``,
                ``max_source_length`` and ``max_target_length``.
            tpu_num_cores: when not None, batches are padded to "max_length"
                instead of "longest".
        """
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        assert self.pad_token_id is not None, "self.pad_token_id must be defined"
        self.tpu_num_cores = tpu_num_cores
        self.data_args = data_args
        # Forwarded to prepare_seq2seq_batch; True only for BartTokenizer instances.
        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` (a list of examples) into a dict of model inputs."""
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            encoded = self._encode(batch)
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]
            labels = encoded["labels"]
        else:
            # Pre-tokenized examples: stack each field, then hand the result to
            # trim_batch (project helper; presumably drops all-padding columns
            # -- confirm against utils.trim_batch).
            input_ids, attention_mask, labels = (
                torch.stack([example[key] for example in batch])
                for key in ("input_ids", "attention_mask", "labels")
            )
            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        # T5 tokenizers use the local right-shift; every other tokenizer goes
        # through shift_tokens_right from transformers.modeling_bart.
        decoder_input_ids = (
            self._shift_right_t5(labels)
            if isinstance(self.tokenizer, T5Tokenizer)
            else shift_tokens_right(labels, self.pad_token_id)
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }

    def _shift_right_t5(self, input_ids):
        """Return a copy of ``input_ids`` shifted one position to the right.

        The first position along the last dimension is filled with
        ``self.pad_token_id``; the last token of every row is dropped.
        """
        shifted = input_ids.new_zeros(input_ids.size())
        shifted[..., 0] = self.pad_token_id
        shifted[..., 1:] = input_ids[..., :-1].clone()
        return shifted

    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize the raw source/target texts of ``batch`` and return the encoding's ``.data``."""
        source_texts = [example["src_texts"] for example in batch]
        target_texts = [example["tgt_texts"] for example in batch]
        if self.tpu_num_cores is not None:
            padding_strategy = "max_length"  # TPU hack
        else:
            padding_strategy = "longest"
        encoding = self.tokenizer.prepare_seq2seq_batch(
            source_texts,
            src_lang=self.data_args.src_lang,
            tgt_texts=target_texts,
            tgt_lang=self.data_args.tgt_lang,
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding=padding_strategy,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        )
        return encoding.data
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
"""
@@ -271,34 +203,6 @@ def main():
), "mBart requires --tgt_lang and --src_lang"
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
    """Return the metrics callback matching ``task_name``.

    The returned function takes an ``EvalPrediction`` and returns a dict of
    scores plus a ``gen_len`` entry. ROUGE is used when ``task_name`` contains
    "summarization"; BLEU is used for every other task name.

    ``tokenizer`` is not a parameter: it is read from the enclosing scope
    (the diff places this definition inside ``main()``).
    """

    def non_pad_len(tokens: np.ndarray) -> int:
        # Number of entries that differ from the tokenizer's pad token id.
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        # Decode predictions and references to text and strip surrounding whitespace.
        # lmap is a project helper (presumably list(map(...)) -- confirm in utils).
        predictions = lmap(str.strip, tokenizer.batch_decode(pred.predictions, skip_special_tokens=True))
        references = lmap(str.strip, tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True))
        return predictions, references

    def mean_gen_len(pred: EvalPrediction):
        # Mean non-pad length over all predictions, rounded to one decimal place.
        return np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        predictions, references = decode_pred(pred)
        scores: Dict = calculate_rouge(predictions, references)
        scores.update({"gen_len": mean_gen_len(pred)})
        return scores

    def translation_metrics(pred: EvalPrediction) -> Dict:
        predictions, references = decode_pred(pred)
        scores: Dict = calculate_bleu(predictions, references)
        scores.update({"gen_len": mean_gen_len(pred)})
        return scores

    if "summarization" in task_name:
        return summarization_metrics
    return translation_metrics
if model_args.freeze_embeds:
freeze_embeds(model)
if model_args.freeze_encoder:
@@ -349,13 +253,17 @@ def main():
)
# Initialize our Trainer
compute_metrics_fn = (
build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
compute_metrics=compute_metrics_fn,
data_args=data_args,
)