Avoid unnecessary device operations in loss computing (#36950)
* Avoid unnecessary tensor copy in loss computing
* Add type annotations
This commit is contained in:
@@ -24,7 +24,13 @@ from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
|
||||
from .loss_rt_detr import RTDetrForObjectDetectionLoss
|
||||
|
||||
|
||||
def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
|
||||
def fixed_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
||||
if reduction == "sum":
|
||||
@@ -38,14 +44,13 @@ def ForCausalLMLoss(
|
||||
vocab_size: int,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels=None,
|
||||
shift_labels: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
|
||||
if shift_labels is None:
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
@@ -60,11 +65,15 @@ def ForCausalLMLoss(
|
||||
|
||||
|
||||
def ForMaskedLMLoss(
|
||||
logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
labels = labels.to(logits.device)
|
||||
|
||||
# Flatten the tokens
|
||||
logits = logits.view(-1, vocab_size)
|
||||
@@ -76,12 +85,12 @@ def ForMaskedLMLoss(
|
||||
return loss
|
||||
|
||||
|
||||
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
|
||||
def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
|
||||
num_labels = config.num_labels
|
||||
if config.problem_type is None:
|
||||
if num_labels == 1:
|
||||
config.problem_type = "regression"
|
||||
elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
|
||||
config.problem_type = "single_label_classification"
|
||||
else:
|
||||
config.problem_type = "multi_label_classification"
|
||||
@@ -90,15 +99,17 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
|
||||
if config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
return loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif config.problem_type == "single_label_classification":
|
||||
loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
|
||||
elif config.problem_type == "multi_label_classification":
|
||||
return loss_fct(pooled_logits, labels)
|
||||
if config.problem_type == "single_label_classification":
|
||||
return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
|
||||
|
||||
if config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
return loss
|
||||
return loss_fct(pooled_logits, labels)
|
||||
|
||||
raise RuntimeError(f"Invalid problem type: {config.problem_type}")
|
||||
|
||||
|
||||
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
|
||||
@@ -120,7 +131,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
|
||||
return total_loss
|
||||
|
||||
|
||||
def ForTokenClassification(logits, labels, config, **kwargs):
|
||||
def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.view(-1, config.num_labels)
|
||||
labels = labels.view(-1).to(logits.device)
|
||||
|
||||
Reference in New Issue
Block a user