[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)

* Make Seq2Seq Trainer more similar to Trainer

* fix typo

* fix seq2seq trainer

* remove from tests

* remove lock

* remove train files

* delete test files

* correct typo

* check at init

* make sure trainer is not slowed down on TPU

* correct isort

* remove use cache

* fix use cache

* add last use chache = false
This commit is contained in:
Patrick von Platen
2020-10-23 23:05:51 +02:00
committed by GitHub
parent 59b5953d89
commit 3c682ea15c
3 changed files with 185 additions and 43 deletions

View File

@@ -1,11 +1,11 @@
import logging
import copy
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers import PreTrainedModel, Trainer, logging
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import (
@@ -27,7 +27,7 @@ except ImportError:
from utils import label_smoothed_nll_loss
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
@@ -41,13 +41,25 @@ arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
def __init__(self, config=None, data_args=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
if config is None:
assert isinstance(
self.model, PreTrainedModel
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
self.config = self._actual_model(self.model).config
else:
self.config = config
self.data_args = data_args
self.max_gen_length = data_args.val_max_target_length
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
assert (
self.config.pad_token_id is not None
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
@@ -114,23 +126,31 @@ class Seq2SeqTrainer(Trainer):
else DistributedSampler(self.train_dataset)
)
def compute_loss(self, model, inputs):
labels = inputs.pop("labels")
outputs = model(**inputs, use_cache=False)
logits = outputs[0]
return self._compute_loss(logits, labels)
def _compute_loss(self, logits, labels):
def _compute_loss(self, model, inputs):
inputs = copy.deepcopy(inputs)
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
assert logits.shape[-1] == self.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
# force training to ignore pad token
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
# compute usual loss via models
loss, logits = model(**inputs, use_cache=False)[:2]
else:
# compute label smoothed loss
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
)
return loss, logits
def compute_loss(self, model, inputs):
loss, _ = self._compute_loss(model, inputs)
return loss
def prediction_step(
@@ -158,31 +178,37 @@ class Seq2SeqTrainer(Trainer):
"""
inputs = self._prepare_inputs(inputs)
if self.args.predict_with_generate and not self.args.prediction_loss_only:
gen_kwargs = {
"max_length": self.data_args.val_max_target_length
if self.data_args is not None
else self.config.max_length,
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
}
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if self.config.pad_token_id is not None:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
# compute loss on predict data
with torch.no_grad():
if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=self.data_args.eval_beams,
max_length=self.max_gen_length,
)
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
loss, logits = self._compute_loss(model, inputs)
labels_out = inputs.get("labels")
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
logits = generated_tokens if self.args.predict_with_generate else logits
labels_out = labels_out.detach()
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
return (loss, logits.detach(), labels)
labels = inputs["labels"]
if self.config.pad_token_id is not None:
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
return (loss, logits, labels)
def _pad_tensors_to_max_len(self, tensor, max_length):
padded_tensor = self.config.pad_token_id * torch.ones(