[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)
* Make Seq2Seq Trainer more similar to Trainer * fix typo * fix seq2seq trainer * remove from tests * remove lock * remove train files * delete test files * correct typo * check at init * make sure trainer is not slowed down on TPU * correct isort * remove use cache * fix use cache * add last use chache = false
This commit is contained in:
committed by
GitHub
parent
59b5953d89
commit
3c682ea15c
@@ -16,7 +16,6 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import EvaluationStrategy
|
from transformers.trainer_utils import EvaluationStrategy
|
||||||
from utils import (
|
from utils import (
|
||||||
LegacySeq2SeqDataset,
|
|
||||||
Seq2SeqDataCollator,
|
Seq2SeqDataCollator,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
@@ -138,6 +137,10 @@ class DataTrainingArguments:
|
|||||||
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
|
||||||
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
|
||||||
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
|
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
|
||||||
|
ignore_pad_token_for_loss: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -223,7 +226,7 @@ def main():
|
|||||||
freeze_params(model.get_encoder())
|
freeze_params(model.get_encoder())
|
||||||
assert_all_frozen(model.get_encoder())
|
assert_all_frozen(model.get_encoder())
|
||||||
|
|
||||||
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
dataset_class = Seq2SeqDataset
|
||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = (
|
train_dataset = (
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import logging
|
import copy
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DistributedSampler, RandomSampler
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import PreTrainedModel, Trainer, logging
|
||||||
from transformers.configuration_fsmt import FSMTConfig
|
from transformers.configuration_fsmt import FSMTConfig
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
from transformers.optimization import (
|
from transformers.optimization import (
|
||||||
@@ -27,7 +27,7 @@ except ImportError:
|
|||||||
from utils import label_smoothed_nll_loss
|
from utils import label_smoothed_nll_loss
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
arg_to_scheduler = {
|
arg_to_scheduler = {
|
||||||
"linear": get_linear_schedule_with_warmup,
|
"linear": get_linear_schedule_with_warmup,
|
||||||
@@ -41,13 +41,25 @@ arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
|||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Trainer):
|
class Seq2SeqTrainer(Trainer):
|
||||||
def __init__(self, config, data_args, *args, **kwargs):
|
def __init__(self, config=None, data_args=None, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.config = config
|
|
||||||
|
if config is None:
|
||||||
|
assert isinstance(
|
||||||
|
self.model, PreTrainedModel
|
||||||
|
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
|
||||||
|
self.config = self._actual_model(self.model).config
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
self.data_args = data_args
|
self.data_args = data_args
|
||||||
self.max_gen_length = data_args.val_max_target_length
|
|
||||||
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||||
|
|
||||||
|
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
|
||||||
|
assert (
|
||||||
|
self.config.pad_token_id is not None
|
||||||
|
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
"""
|
"""
|
||||||
Setup the optimizer and the learning rate scheduler.
|
Setup the optimizer and the learning rate scheduler.
|
||||||
@@ -114,23 +126,31 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
else DistributedSampler(self.train_dataset)
|
else DistributedSampler(self.train_dataset)
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_loss(self, model, inputs):
|
def _compute_loss(self, model, inputs):
|
||||||
labels = inputs.pop("labels")
|
inputs = copy.deepcopy(inputs)
|
||||||
outputs = model(**inputs, use_cache=False)
|
|
||||||
logits = outputs[0]
|
|
||||||
return self._compute_loss(logits, labels)
|
|
||||||
|
|
||||||
def _compute_loss(self, logits, labels):
|
|
||||||
if self.args.label_smoothing == 0:
|
if self.args.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py
|
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
# force training to ignore pad token
|
||||||
assert logits.shape[-1] == self.vocab_size
|
labels = inputs.pop("labels")
|
||||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||||
|
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
|
else:
|
||||||
|
# compute usual loss via models
|
||||||
|
loss, logits = model(**inputs, use_cache=False)[:2]
|
||||||
else:
|
else:
|
||||||
|
# compute label smoothed loss
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
loss, nll_loss = label_smoothed_nll_loss(
|
loss, _ = label_smoothed_nll_loss(
|
||||||
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
||||||
)
|
)
|
||||||
|
return loss, logits
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs):
|
||||||
|
loss, _ = self._compute_loss(model, inputs)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
@@ -158,31 +178,37 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||||
|
gen_kwargs = {
|
||||||
|
"max_length": self.data_args.val_max_target_length
|
||||||
|
if self.data_args is not None
|
||||||
|
else self.config.max_length,
|
||||||
|
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
|
||||||
|
}
|
||||||
|
generated_tokens = model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
**gen_kwargs,
|
||||||
|
)
|
||||||
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
|
if self.config.pad_token_id is not None:
|
||||||
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||||
|
|
||||||
|
# compute loss on predict data
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
loss, logits = self._compute_loss(model, inputs)
|
||||||
generated_tokens = model.generate(
|
|
||||||
inputs["input_ids"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
use_cache=True,
|
|
||||||
num_beams=self.data_args.eval_beams,
|
|
||||||
max_length=self.max_gen_length,
|
|
||||||
)
|
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
|
||||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
|
|
||||||
|
|
||||||
labels_out = inputs.get("labels")
|
loss = loss.mean().detach()
|
||||||
# Call forward again to get loss # TODO: avoidable?
|
if self.args.prediction_loss_only:
|
||||||
outputs = model(**inputs, use_cache=False)
|
return (loss, None, None)
|
||||||
loss = self._compute_loss(outputs[1], labels_out)
|
|
||||||
loss = loss.mean().detach()
|
|
||||||
if self.args.prediction_loss_only:
|
|
||||||
return (loss, None, None)
|
|
||||||
|
|
||||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||||
|
|
||||||
labels_out = labels_out.detach()
|
labels = inputs["labels"]
|
||||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
|
if self.config.pad_token_id is not None:
|
||||||
return (loss, logits.detach(), labels)
|
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
|
||||||
|
|
||||||
|
return (loss, logits, labels)
|
||||||
|
|
||||||
def _pad_tensors_to_max_len(self, tensor, max_length):
|
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||||
padded_tensor = self.config.pad_token_id * torch.ones(
|
padded_tensor = self.config.pad_token_id * torch.ones(
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
|
||||||
|
from transformers.file_utils import is_datasets_available
|
||||||
from transformers.testing_utils import TestCasePlus, slow
|
from transformers.testing_utils import TestCasePlus, slow
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .finetune_trainer import main
|
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||||
|
from .seq2seq_trainer import Seq2SeqTrainer
|
||||||
from .test_seq2seq_examples import MBART_TINY
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
from .utils import execute_async_std
|
from .utils import execute_async_std
|
||||||
|
|
||||||
@@ -50,6 +52,117 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
assert "test_generations.txt" in contents
|
assert "test_generations.txt" in contents
|
||||||
assert "test_results.json" in contents
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_finetune_bert2bert(self):
|
||||||
|
if not is_datasets_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||||
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||||
|
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||||
|
|
||||||
|
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||||
|
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||||
|
|
||||||
|
train_dataset = train_dataset.select(range(32))
|
||||||
|
val_dataset = val_dataset.select(range(16))
|
||||||
|
|
||||||
|
rouge = datasets.load_metric("rouge")
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
|
||||||
|
def _map_to_encoder_decoder_inputs(batch):
|
||||||
|
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||||
|
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
||||||
|
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
||||||
|
batch["input_ids"] = inputs.input_ids
|
||||||
|
batch["attention_mask"] = inputs.attention_mask
|
||||||
|
|
||||||
|
batch["decoder_input_ids"] = outputs.input_ids
|
||||||
|
batch["labels"] = outputs.input_ids.copy()
|
||||||
|
batch["labels"] = [
|
||||||
|
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
||||||
|
]
|
||||||
|
batch["decoder_attention_mask"] = outputs.attention_mask
|
||||||
|
|
||||||
|
assert all([len(x) == 512 for x in inputs.input_ids])
|
||||||
|
assert all([len(x) == 128 for x in outputs.input_ids])
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def _compute_metrics(pred):
|
||||||
|
labels_ids = pred.label_ids
|
||||||
|
pred_ids = pred.predictions
|
||||||
|
|
||||||
|
# all unnecessary tokens are removed
|
||||||
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||||
|
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
||||||
|
"rouge2"
|
||||||
|
].mid
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rouge2_precision": round(rouge_output.precision, 4),
|
||||||
|
"rouge2_recall": round(rouge_output.recall, 4),
|
||||||
|
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
# map train dataset
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
_map_to_encoder_decoder_inputs,
|
||||||
|
batched=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
remove_columns=["article", "highlights"],
|
||||||
|
)
|
||||||
|
train_dataset.set_format(
|
||||||
|
type="torch",
|
||||||
|
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# same for validation dataset
|
||||||
|
val_dataset = val_dataset.map(
|
||||||
|
_map_to_encoder_decoder_inputs,
|
||||||
|
batched=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
remove_columns=["article", "highlights"],
|
||||||
|
)
|
||||||
|
val_dataset.set_format(
|
||||||
|
type="torch",
|
||||||
|
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
per_device_train_batch_size=batch_size,
|
||||||
|
per_device_eval_batch_size=batch_size,
|
||||||
|
predict_with_generate=True,
|
||||||
|
evaluate_during_training=True,
|
||||||
|
do_train=True,
|
||||||
|
do_eval=True,
|
||||||
|
warmup_steps=0,
|
||||||
|
eval_steps=2,
|
||||||
|
logging_steps=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate trainer
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=bert2bert,
|
||||||
|
args=training_args,
|
||||||
|
compute_metrics=_compute_metrics,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# start training
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||||
|
|
||||||
# XXX: remove hardcoded path
|
# XXX: remove hardcoded path
|
||||||
|
|||||||
Reference in New Issue
Block a user