[examples/s2s] clean up finetune_trainer (#7509)
This commit is contained in:
@@ -2,37 +2,29 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from seq2seq_trainer import Seq2SeqTrainer
|
from seq2seq_trainer import Seq2SeqTrainer
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BartTokenizer,
|
|
||||||
EvalPrediction,
|
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
MBartTokenizer,
|
MBartTokenizer,
|
||||||
T5Tokenizer,
|
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
|
||||||
from transformers.trainer_utils import EvaluationStrategy
|
from transformers.trainer_utils import EvaluationStrategy
|
||||||
from utils import (
|
from utils import (
|
||||||
LegacySeq2SeqDataset,
|
LegacySeq2SeqDataset,
|
||||||
|
Seq2SeqDataCollator,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
build_compute_metrics_fn,
|
||||||
calculate_rouge,
|
|
||||||
freeze_embeds,
|
freeze_embeds,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
lmap,
|
lmap,
|
||||||
save_json,
|
save_json,
|
||||||
trim_batch,
|
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
write_txt_file,
|
write_txt_file,
|
||||||
)
|
)
|
||||||
@@ -41,66 +33,6 @@ from utils import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqDataCollator:
|
|
||||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.pad_token_id = tokenizer.pad_token_id
|
|
||||||
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
|
||||||
self.data_args = data_args
|
|
||||||
self.tpu_num_cores = tpu_num_cores
|
|
||||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
|
||||||
|
|
||||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
|
||||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
|
||||||
batch = self._encode(batch)
|
|
||||||
input_ids, attention_mask, labels = (
|
|
||||||
batch["input_ids"],
|
|
||||||
batch["attention_mask"],
|
|
||||||
batch["labels"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
|
||||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
|
||||||
labels = torch.stack([x["labels"] for x in batch])
|
|
||||||
|
|
||||||
labels = trim_batch(labels, self.pad_token_id)
|
|
||||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
if isinstance(self.tokenizer, T5Tokenizer):
|
|
||||||
decoder_input_ids = self._shift_right_t5(labels)
|
|
||||||
else:
|
|
||||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
|
||||||
|
|
||||||
batch = {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"decoder_input_ids": decoder_input_ids,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def _shift_right_t5(self, input_ids):
|
|
||||||
# shift inputs to the right
|
|
||||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
|
||||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
|
||||||
shifted_input_ids[..., 0] = self.pad_token_id
|
|
||||||
return shifted_input_ids
|
|
||||||
|
|
||||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
|
||||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
|
||||||
[x["src_texts"] for x in batch],
|
|
||||||
src_lang=self.data_args.src_lang,
|
|
||||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
|
||||||
tgt_lang=self.data_args.tgt_lang,
|
|
||||||
max_length=self.data_args.max_source_length,
|
|
||||||
max_target_length=self.data_args.max_target_length,
|
|
||||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
|
||||||
return_tensors="pt",
|
|
||||||
add_prefix_space=self.add_prefix_space,
|
|
||||||
)
|
|
||||||
return batch_encoding.data
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Seq2SeqTrainingArguments(TrainingArguments):
|
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||||
"""
|
"""
|
||||||
@@ -271,34 +203,6 @@ def main():
|
|||||||
), "mBart requires --tgt_lang and --src_lang"
|
), "mBart requires --tgt_lang and --src_lang"
|
||||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||||
|
|
||||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
|
||||||
def non_pad_len(tokens: np.ndarray) -> int:
|
|
||||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
|
||||||
|
|
||||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
|
||||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
|
||||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
|
||||||
pred_str = lmap(str.strip, pred_str)
|
|
||||||
label_str = lmap(str.strip, label_str)
|
|
||||||
return pred_str, label_str
|
|
||||||
|
|
||||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
|
||||||
pred_str, label_str = decode_pred(pred)
|
|
||||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
|
||||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
|
||||||
rouge.update({"gen_len": summ_len})
|
|
||||||
return rouge
|
|
||||||
|
|
||||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
|
||||||
pred_str, label_str = decode_pred(pred)
|
|
||||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
|
||||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
|
||||||
bleu.update({"gen_len": gen_len})
|
|
||||||
return bleu
|
|
||||||
|
|
||||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
|
||||||
return compute_metrics_fn
|
|
||||||
|
|
||||||
if model_args.freeze_embeds:
|
if model_args.freeze_embeds:
|
||||||
freeze_embeds(model)
|
freeze_embeds(model)
|
||||||
if model_args.freeze_encoder:
|
if model_args.freeze_encoder:
|
||||||
@@ -349,13 +253,17 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
|
compute_metrics_fn = (
|
||||||
|
build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
|
||||||
|
)
|
||||||
trainer = Seq2SeqTrainer(
|
trainer = Seq2SeqTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
|
config=config,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
compute_metrics=compute_metrics_fn,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
def __init__(self, data_args, *args, **kwargs):
|
def __init__(self, config, data_args, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.config = config
|
||||||
self.data_args = data_args
|
self.data_args = data_args
|
||||||
self.max_gen_length = data_args.val_max_target_length
|
self.max_gen_length = data_args.val_max_target_length
|
||||||
self.pad_token_id = self.model.config.pad_token_id
|
self.pad_token_id = self.config.pad_token_id
|
||||||
|
self.vocab_size = self.config.vocab_size
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py
|
# Same behavior as modeling_bart.py
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||||
assert logits.shape[-1] == self.model.config.vocab_size
|
assert logits.shape[-1] == self.vocab_size
|
||||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pickle
|
|||||||
import socket
|
import socket
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List, Union
|
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -19,8 +19,9 @@ from torch import nn
|
|||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
|
|
||||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||||
from transformers import BartTokenizer
|
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
|||||||
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||||
|
|
||||||
|
|
||||||
|
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
||||||
|
def non_pad_len(tokens: np.ndarray) -> int:
|
||||||
|
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||||
|
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||||
|
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||||
|
pred_str = lmap(str.strip, pred_str)
|
||||||
|
label_str = lmap(str.strip, label_str)
|
||||||
|
return pred_str, label_str
|
||||||
|
|
||||||
|
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||||
|
pred_str, label_str = decode_pred(pred)
|
||||||
|
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||||
|
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||||
|
rouge.update({"gen_len": summ_len})
|
||||||
|
return rouge
|
||||||
|
|
||||||
|
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||||
|
pred_str, label_str = decode_pred(pred)
|
||||||
|
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||||
|
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||||
|
bleu.update({"gen_len": gen_len})
|
||||||
|
return bleu
|
||||||
|
|
||||||
|
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||||
|
return compute_metrics_fn
|
||||||
|
|
||||||
|
|
||||||
def trim_batch(
|
def trim_batch(
|
||||||
input_ids,
|
input_ids,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
@@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
|||||||
return batch_encoding
|
return batch_encoding
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqDataCollator:
|
||||||
|
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.pad_token_id = tokenizer.pad_token_id
|
||||||
|
assert (
|
||||||
|
self.pad_token_id is not None
|
||||||
|
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||||
|
self.data_args = data_args
|
||||||
|
self.tpu_num_cores = tpu_num_cores
|
||||||
|
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||||
|
|
||||||
|
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||||
|
batch = self._encode(batch)
|
||||||
|
input_ids, attention_mask, labels = (
|
||||||
|
batch["input_ids"],
|
||||||
|
batch["attention_mask"],
|
||||||
|
batch["labels"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||||
|
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||||
|
labels = torch.stack([x["labels"] for x in batch])
|
||||||
|
|
||||||
|
labels = trim_batch(labels, self.pad_token_id)
|
||||||
|
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
if isinstance(self.tokenizer, T5Tokenizer):
|
||||||
|
decoder_input_ids = self._shift_right_t5(labels)
|
||||||
|
else:
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def _shift_right_t5(self, input_ids):
|
||||||
|
# shift inputs to the right
|
||||||
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||||
|
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||||
|
shifted_input_ids[..., 0] = self.pad_token_id
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
|
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
[x["src_texts"] for x in batch],
|
||||||
|
src_lang=self.data_args.src_lang,
|
||||||
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
|
tgt_lang=self.data_args.tgt_lang,
|
||||||
|
max_length=self.data_args.max_source_length,
|
||||||
|
max_target_length=self.data_args.max_target_length,
|
||||||
|
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||||
|
return_tensors="pt",
|
||||||
|
add_prefix_space=self.add_prefix_space,
|
||||||
|
)
|
||||||
|
return batch_encoding.data
|
||||||
|
|
||||||
|
|
||||||
class SortishSampler(Sampler):
|
class SortishSampler(Sampler):
|
||||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user