[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)
* Make Seq2Seq Trainer more similar to Trainer * fix typo * fix seq2seq trainer * remove from tests * remove lock * remove train files * delete test files * correct typo * check at init * make sure trainer is not slowed down on TPU * correct isort * remove use cache * fix use cache * add last use chache = false
This commit is contained in:
committed by
GitHub
parent
59b5953d89
commit
3c682ea15c
@@ -5,12 +5,14 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.testing_utils import TestCasePlus, slow
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import execute_async_std
|
||||
|
||||
@@ -50,6 +52,117 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
@slow
|
||||
def test_finetune_bert2bert(self):
|
||||
if not is_datasets_available():
|
||||
return
|
||||
|
||||
import datasets
|
||||
|
||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||
|
||||
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||
|
||||
train_dataset = train_dataset.select(range(32))
|
||||
val_dataset = val_dataset.select(range(16))
|
||||
|
||||
rouge = datasets.load_metric("rouge")
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def _map_to_encoder_decoder_inputs(batch):
|
||||
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
||||
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
||||
batch["input_ids"] = inputs.input_ids
|
||||
batch["attention_mask"] = inputs.attention_mask
|
||||
|
||||
batch["decoder_input_ids"] = outputs.input_ids
|
||||
batch["labels"] = outputs.input_ids.copy()
|
||||
batch["labels"] = [
|
||||
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
||||
]
|
||||
batch["decoder_attention_mask"] = outputs.attention_mask
|
||||
|
||||
assert all([len(x) == 512 for x in inputs.input_ids])
|
||||
assert all([len(x) == 128 for x in outputs.input_ids])
|
||||
|
||||
return batch
|
||||
|
||||
def _compute_metrics(pred):
|
||||
labels_ids = pred.label_ids
|
||||
pred_ids = pred.predictions
|
||||
|
||||
# all unnecessary tokens are removed
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||
|
||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
||||
"rouge2"
|
||||
].mid
|
||||
|
||||
return {
|
||||
"rouge2_precision": round(rouge_output.precision, 4),
|
||||
"rouge2_recall": round(rouge_output.recall, 4),
|
||||
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
||||
}
|
||||
|
||||
# map train dataset
|
||||
train_dataset = train_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
train_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
# same for validation dataset
|
||||
val_dataset = val_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
val_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
output_dir=output_dir,
|
||||
per_device_train_batch_size=batch_size,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
predict_with_generate=True,
|
||||
evaluate_during_training=True,
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
warmup_steps=0,
|
||||
eval_steps=2,
|
||||
logging_steps=2,
|
||||
)
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=bert2bert,
|
||||
args=training_args,
|
||||
compute_metrics=_compute_metrics,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
)
|
||||
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
|
||||
# XXX: remove hardcoded path
|
||||
|
||||
Reference in New Issue
Block a user