[s2s] add config params like Dropout in Seq2SeqTrainingArguments (#7532)
This commit is contained in:
@@ -6,6 +6,7 @@ from torch import nn
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.configuration_fsmt import FSMTConfig
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||
from transformers.trainer import get_tpu_sampler
|
||||
@@ -26,8 +27,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
self.config = config
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
self.pad_token_id = self.config.pad_token_id
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
|
||||
|
||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||
"""
|
||||
@@ -87,18 +87,18 @@ class Seq2SeqTrainer(Trainer):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
logits = outputs[0]
|
||||
return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
|
||||
return self._compute_loss(logits, labels)
|
||||
|
||||
def _compute_loss(self, logits, labels, ignore_index):
|
||||
def _compute_loss(self, logits, labels):
|
||||
if self.args.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||
assert logits.shape[-1] == self.vocab_size
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
|
||||
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
||||
)
|
||||
return loss
|
||||
|
||||
@@ -137,14 +137,12 @@ class Seq2SeqTrainer(Trainer):
|
||||
max_length=self.max_gen_length,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
generated_tokens = self._pad_tensors_to_max_len(
|
||||
generated_tokens, self.max_gen_length, self.pad_token_id
|
||||
)
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
|
||||
|
||||
labels_out = inputs.get("labels")
|
||||
# Call forward again to get loss # TODO: avoidable?
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
|
||||
loss = self._compute_loss(outputs[1], labels_out)
|
||||
loss = loss.mean().item()
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
@@ -152,11 +150,11 @@ class Seq2SeqTrainer(Trainer):
|
||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||
|
||||
labels_out = labels_out.detach()
|
||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
|
||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||
padded_tensor = pad_token_id * torch.ones(
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||
padded_tensor = self.config.pad_token_id * torch.ones(
|
||||
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||
|
||||
Reference in New Issue
Block a user