[examples/s2s] clean up finetune_trainer (#7509)
This commit is contained in:
@@ -2,37 +2,29 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
T5Tokenizer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from transformers.trainer_utils import EvaluationStrategy
|
||||
from utils import (
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataCollator,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
build_compute_metrics_fn,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
lmap,
|
||||
save_json,
|
||||
trim_batch,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
@@ -41,66 +33,6 @@ from utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["labels"],
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||
labels = torch.stack([x["labels"] for x in batch])
|
||||
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.data_args.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.data_args.tgt_lang,
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
@@ -271,34 +203,6 @@ def main():
|
||||
), "mBart requires --tgt_lang and --src_lang"
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
pred_str = lmap(str.strip, pred_str)
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
if model_args.freeze_encoder:
|
||||
@@ -349,13 +253,17 @@ def main():
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
compute_metrics_fn = (
|
||||
build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
|
||||
)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||
compute_metrics=compute_metrics_fn,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user