Avoid unnecessary device operations in loss computation (#36950)

* Avoid unnecessary tensor copies in loss computation

* Add type annotations to the loss function signatures
This commit is contained in:
cyyever
2025-03-27 22:45:14 +08:00
committed by GitHub
parent 471cf1de63
commit 8c5e29bad5

View File

@@ -24,7 +24,13 @@ from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss from .loss_rt_detr import RTDetrForObjectDetectionLoss
def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs): def fixed_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
**kwargs,
) -> torch.Tensor:
reduction = "sum" if num_items_in_batch is not None else "mean" reduction = "sum" if num_items_in_batch is not None else "mean"
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if reduction == "sum": if reduction == "sum":
@@ -38,14 +44,13 @@ def ForCausalLMLoss(
vocab_size: int, vocab_size: int,
num_items_in_batch: Optional[int] = None, num_items_in_batch: Optional[int] = None,
ignore_index: int = -100, ignore_index: int = -100,
shift_labels=None, shift_labels: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
): ) -> torch.Tensor:
# Upcast to float if we need to compute the loss to avoid potential precision issues # Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float() logits = logits.float()
if shift_labels is None: if shift_labels is None:
labels = labels.to(logits.device)
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
labels = nn.functional.pad(labels, (0, 1), value=ignore_index) labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
@@ -60,11 +65,15 @@ def ForCausalLMLoss(
def ForMaskedLMLoss( def ForMaskedLMLoss(
logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs logits: torch.Tensor,
labels: torch.Tensor,
vocab_size: int,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
**kwargs,
): ):
# Upcast to float if we need to compute the loss to avoid potential precision issues # Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float() logits = logits.float()
labels = labels.to(logits.device)
# Flatten the tokens # Flatten the tokens
logits = logits.view(-1, vocab_size) logits = logits.view(-1, vocab_size)
@@ -76,12 +85,12 @@ def ForMaskedLMLoss(
return loss return loss
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
num_labels = config.num_labels num_labels = config.num_labels
if config.problem_type is None: if config.problem_type is None:
if num_labels == 1: if num_labels == 1:
config.problem_type = "regression" config.problem_type = "regression"
elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
config.problem_type = "single_label_classification" config.problem_type = "single_label_classification"
else: else:
config.problem_type = "multi_label_classification" config.problem_type = "multi_label_classification"
@@ -90,15 +99,17 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
if config.problem_type == "regression": if config.problem_type == "regression":
loss_fct = MSELoss() loss_fct = MSELoss()
if num_labels == 1: if num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) return loss_fct(pooled_logits.squeeze(), labels.squeeze())
else: else:
loss = loss_fct(pooled_logits, labels) return loss_fct(pooled_logits, labels)
elif config.problem_type == "single_label_classification": if config.problem_type == "single_label_classification":
loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs) return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
elif config.problem_type == "multi_label_classification":
if config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss() loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels) return loss_fct(pooled_logits, labels)
return loss
raise RuntimeError(f"Invalid problem type: {config.problem_type}")
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs): def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
@@ -120,7 +131,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
return total_loss return total_loss
def ForTokenClassification(logits, labels, config, **kwargs): def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
# Upcast to float if we need to compute the loss to avoid potential precision issues # Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.view(-1, config.num_labels) logits = logits.view(-1, config.num_labels)
labels = labels.view(-1).to(logits.device) labels = labels.view(-1).to(logits.device)