Change model outputs types to self-document outputs (#5438)
* [WIP] Proposal for model outputs * All Bert models * Make CI green maybe? * Fix ONNX test * Isolate ModelOutput from pt and tf * Formatting * Add Electra models * Auto-generate docstrings from outputs * Add TF outputs * Add some BERT models * Revert TF side * Remove last traces of TF changes * Fail with a clear error message * Add Albert and work through Bart * Add CTRL and DistilBert * Formatting * Progress on Bart * Renames and finish Bart * Formatting * Fix last test * Add DPR * Finish Electra and add FlauBERT * Add GPT2 * Add Longformer * Add MMBT * Add MobileBert * Add GPT * Formatting * Add Reformer * Add Roberta * Add T5 * Add Transformer XL * Fix test * Add XLM + fix XLMForTokenClassification * Style + XLMRoberta * Add XLNet * Formatting * Add doc of return_tuple arg
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -31,6 +32,7 @@ from .file_utils import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ModelOutput,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_remote_url,
|
||||
@@ -941,6 +943,35 @@ class PoolerAnswerClass(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class SquadHeadOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of question answering models using a :obj:`SquadHead`.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
start_top_log_probs: Optional[torch.FloatTensor] = None
|
||||
start_top_index: Optional[torch.LongTensor] = None
|
||||
end_top_log_probs: Optional[torch.FloatTensor] = None
|
||||
end_top_index: Optional[torch.LongTensor] = None
|
||||
cls_logits: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class SQuADHead(nn.Module):
|
||||
r""" A SQuAD head inspired by XLNet.
|
||||
|
||||
@@ -992,10 +1023,15 @@ class SQuADHead(nn.Module):
|
||||
self.answer_class = PoolerAnswerClass(config)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
|
||||
self,
|
||||
hidden_states,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
is_impossible=None,
|
||||
p_mask=None,
|
||||
return_tuple=False,
|
||||
):
|
||||
outputs = ()
|
||||
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
@@ -1021,7 +1057,7 @@ class SQuADHead(nn.Module):
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||
total_loss += cls_loss * 0.5
|
||||
|
||||
outputs = (total_loss,) + outputs
|
||||
return (total_loss,) if return_tuple else SquadHeadOutput(loss=total_loss)
|
||||
|
||||
else:
|
||||
# during inference, compute the end logits based on beam search
|
||||
@@ -1051,11 +1087,16 @@ class SQuADHead(nn.Module):
|
||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
||||
|
||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
|
||||
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||
# or (if labels are provided) (total_loss,)
|
||||
return outputs
|
||||
if return_tuple:
|
||||
return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
|
||||
else:
|
||||
return SquadHeadOutput(
|
||||
start_top_log_probs=start_top_log_probs,
|
||||
start_top_index=start_top_index,
|
||||
end_top_log_probs=end_top_log_probs,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits,
|
||||
)
|
||||
|
||||
|
||||
class SequenceSummary(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user