[s2sTrainer] test + code cleanup (#7467)
This commit is contained in:
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def __init__(self, data_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
self.pad_token_id = self.model.config.pad_token_id
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
return None
|
||||
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
logits = outputs[0]
|
||||
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
|
||||
return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)
|
||||
|
||||
def _compute_loss(self, logits, labels, ignore_index):
|
||||
if self.args.label_smoothing == 0:
|
||||
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
|
||||
"""
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
max_length = (
|
||||
model.config.max_generate_length
|
||||
if hasattr(model.config, "max_generate_length")
|
||||
else model.config.max_position_embeddings
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.predict_with_generate and not self.args.prediction_loss_only:
|
||||
generated_tokens = model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
use_cache=True,
|
||||
num_beams=model.config.num_beams,
|
||||
max_length=max_length,
|
||||
num_beams=self.data_args.eval_beams,
|
||||
max_length=self.max_gen_length,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
generated_tokens = self._pad_tensors_to_max_len(
|
||||
generated_tokens, max_length, model.config.pad_token_id
|
||||
generated_tokens, self.max_gen_length, self.pad_token_id
|
||||
)
|
||||
|
||||
labels_out = inputs.get("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs[1]
|
||||
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
|
||||
# Call forward again to get loss # TODO: avoidable?
|
||||
outputs = model(**inputs, use_cache=False)
|
||||
loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
|
||||
loss = loss.mean().item()
|
||||
if self.args.prediction_loss_only:
|
||||
logits = None
|
||||
else:
|
||||
logits = generated_tokens if self.args.predict_with_generate else logits
|
||||
return (loss, None, None)
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
|
||||
|
||||
labels_out = labels_out.detach()
|
||||
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
|
||||
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
|
||||
|
||||
Reference in New Issue
Block a user