[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)
* Make Seq2Seq Trainer more similar to Trainer * fix typo * fix seq2seq trainer * remove from tests * remove lock * remove train files * delete test files * correct typo * check at init * make sure trainer is not slowed down on TPU * correct isort * remove use cache * fix use cache * add last use chache = false
This commit is contained in:
committed by
GitHub
parent
59b5953d89
commit
3c682ea15c
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import copy
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers import PreTrainedModel, Trainer, logging
|
||||
from transformers.configuration_fsmt import FSMTConfig
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.optimization import (
|
||||
@@ -27,7 +27,7 @@ except ImportError:
|
||||
from utils import label_smoothed_nll_loss
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
arg_to_scheduler = {
|
||||
"linear": get_linear_schedule_with_warmup,
|
||||
@@ -41,13 +41,25 @@ arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def __init__(self, config, data_args, *args, **kwargs):
|
||||
def __init__(self, config=None, data_args=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
|
||||
if config is None:
|
||||
assert isinstance(
|
||||
self.model, PreTrainedModel
|
||||
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
|
||||
self.config = self._actual_model(self.model).config
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||
|
||||
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
|
||||
assert (
|
||||
self.config.pad_token_id is not None
|
||||
), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
|
||||
|
||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||
"""
|
||||
Setup the optimizer and the learning rate scheduler.
|
||||
@@ -114,23 +126,31 @@ class Seq2SeqTrainer(Trainer):
|
||||
else DistributedSampler(self.train_dataset)
|
||||
)
|
||||
|
||||
def compute_loss(self, model, inputs):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
logits = outputs[0]
|
||||
return self._compute_loss(logits, labels)
|
||||
|
||||
def _compute_loss(self, logits, labels):
|
||||
def _compute_loss(self, model, inputs):
|
||||
inputs = copy.deepcopy(inputs)
|
||||
if self.args.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||
assert logits.shape[-1] == self.vocab_size
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
||||
# force training to ignore pad token
|
||||
labels = inputs.pop("labels")
|
||||
logits = model(**inputs, use_cache=False)[0]
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
# compute usual loss via models
|
||||
loss, logits = model(**inputs, use_cache=False)[:2]
|
||||
else:
|
||||
# compute label smoothed loss
|
||||
labels = inputs.pop("labels")
|
||||
logits = model(**inputs, use_cache=False)[0]
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
loss, _ = label_smoothed_nll_loss(
|
||||
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
||||
)
|
||||
return loss, logits
|
||||
|
||||
def compute_loss(self, model, inputs):
|
||||
loss, _ = self._compute_loss(model, inputs)
|
||||
return loss
|
||||
|
||||
def prediction_step(
|
||||
@@ -158,31 +178,37 @@ class Seq2SeqTrainer(Trainer):
|
||||
"""
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||
gen_kwargs = {
|
||||
"max_length": self.data_args.val_max_target_length
|
||||
if self.data_args is not None
|
||||
else self.config.max_length,
|
||||
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
|
||||
}
|
||||
generated_tokens = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
**gen_kwargs,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
if self.config.pad_token_id is not None:
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||
|
||||
# compute loss on predict data
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||
generated_tokens = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
use_cache=True,
|
||||
num_beams=self.data_args.eval_beams,
|
||||
max_length=self.max_gen_length,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
|
||||
loss, logits = self._compute_loss(model, inputs)
|
||||
|
||||
labels_out = inputs.get("labels")
|
||||
# Call forward again to get loss # TODO: avoidable?
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
loss = self._compute_loss(outputs[1], labels_out)
|
||||
loss = loss.mean().detach()
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
loss = loss.mean().detach()
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||
|
||||
labels_out = labels_out.detach()
|
||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
|
||||
return (loss, logits.detach(), labels)
|
||||
labels = inputs["labels"]
|
||||
if self.config.pad_token_id is not None:
|
||||
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||
padded_tensor = self.config.pad_token_id * torch.ones(
|
||||
|
||||
Reference in New Issue
Block a user