Seq2SeqTrainer (#6769)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Suraj Patil
2020-09-25 04:16:58 +05:30
committed by GitHub
parent d9d0f1140b
commit 9e68d075a4
11 changed files with 879 additions and 0 deletions

View File

@@ -0,0 +1,126 @@
import logging
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer import get_tpu_sampler
try:
from .utils import label_smoothed_nll_loss
except ImportError:
from utils import label_smoothed_nll_loss
logger = logging.getLogger(__name__)
class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
)
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
def compute_loss(self, model, inputs):
labels = inputs.pop("labels")
outputs = model(**inputs, use_cache=False)
logits = outputs[0]
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
def _compute_loss(self, logits, labels, ignore_index):
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
assert logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
)
return loss
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
inputs = self._prepare_inputs(inputs)
max_length = (
model.config.max_generate_length
if hasattr(model.config, "max_generate_length")
else model.config.max_position_embeddings
)
with torch.no_grad():
if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=model.config.num_beams,
max_length=max_length,
)
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(
generated_tokens, max_length, model.config.pad_token_id
)
labels_out = inputs.get("labels")
outputs = model(**inputs)
logits = outputs[1]
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
loss = loss.mean().item()
if self.args.prediction_loss_only:
logits = None
else:
logits = generated_tokens if self.args.predict_with_generate else logits
if self.args.prediction_loss_only:
return (loss, None, None)
labels_out = labels_out.detach()
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
return (loss, logits.detach(), labels)
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor