From 8c5e29bad510749f7b85c60197b8c3ebbe82192f Mon Sep 17 00:00:00 2001 From: cyyever Date: Thu, 27 Mar 2025 22:45:14 +0800 Subject: [PATCH] Avoid unnecessary device operations in loss computing (#36950) * Avoid unnecessary tensor copy in loss computing * Add type --- src/transformers/loss/loss_utils.py | 43 ++++++++++++++++++----------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 56cc7a1ebd..0e052aed6a 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -24,7 +24,13 @@ from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_rt_detr import RTDetrForObjectDetectionLoss -def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs): +def fixed_cross_entropy( + source: torch.Tensor, + target: torch.Tensor, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + **kwargs, +) -> torch.Tensor: reduction = "sum" if num_items_in_batch is not None else "mean" loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) if reduction == "sum": @@ -38,14 +44,13 @@ def ForCausalLMLoss( vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, - shift_labels=None, + shift_labels: Optional[torch.Tensor] = None, **kwargs, -): +) -> torch.Tensor: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() if shift_labels is None: - labels = labels.to(logits.device) # Shift so that tokens < n predict n labels = nn.functional.pad(labels, (0, 1), value=ignore_index) shift_labels = labels[..., 1:].contiguous() @@ -60,11 +65,15 @@ def ForCausalLMLoss( def ForMaskedLMLoss( - logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs + logits: torch.Tensor, + labels: torch.Tensor, + vocab_size: int, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + **kwargs, ): # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() - labels = labels.to(logits.device) # Flatten the tokens logits = logits.view(-1, vocab_size) @@ -76,12 +85,12 @@ def ForMaskedLMLoss( return loss -def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): +def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor: num_labels = config.num_labels if config.problem_type is None: if num_labels == 1: config.problem_type = "regression" - elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)): config.problem_type = "single_label_classification" else: config.problem_type = "multi_label_classification" @@ -90,15 +99,17 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs): if config.problem_type == "regression": loss_fct = MSELoss() if num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + return loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: - loss = loss_fct(pooled_logits, labels) - elif config.problem_type == "single_label_classification": - loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs) - elif config.problem_type == "multi_label_classification": + return loss_fct(pooled_logits, labels) + if config.problem_type == "single_label_classification": + return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs) + + if config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - return loss + return loss_fct(pooled_logits, labels) + + raise RuntimeError(f"Invalid problem type: {config.problem_type}") def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs): @@ -120,7 +131,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi return total_loss -def ForTokenClassification(logits, labels, config, **kwargs): +def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.view(-1, config.num_labels) labels = labels.view(-1).to(logits.device)