Avoid unnecessary device operations in loss computing (#36950)
* Avoid unnecessary tensor copy in loss computing * Add type
This commit is contained in:
@@ -24,7 +24,13 @@ from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
|
|||||||
from .loss_rt_detr import RTDetrForObjectDetectionLoss
|
from .loss_rt_detr import RTDetrForObjectDetectionLoss
|
||||||
|
|
||||||
|
|
||||||
def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
|
def fixed_cross_entropy(
|
||||||
|
source: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
num_items_in_batch: Optional[int] = None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||||
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
|
||||||
if reduction == "sum":
|
if reduction == "sum":
|
||||||
@@ -38,14 +44,13 @@ def ForCausalLMLoss(
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
num_items_in_batch: Optional[int] = None,
|
num_items_in_batch: Optional[int] = None,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
shift_labels=None,
|
shift_labels: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> torch.Tensor:
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
|
|
||||||
if shift_labels is None:
|
if shift_labels is None:
|
||||||
labels = labels.to(logits.device)
|
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
@@ -60,11 +65,15 @@ def ForCausalLMLoss(
|
|||||||
|
|
||||||
|
|
||||||
def ForMaskedLMLoss(
|
def ForMaskedLMLoss(
|
||||||
logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
|
logits: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
vocab_size: int,
|
||||||
|
num_items_in_batch: Optional[int] = None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
labels = labels.to(logits.device)
|
|
||||||
|
|
||||||
# Flatten the tokens
|
# Flatten the tokens
|
||||||
logits = logits.view(-1, vocab_size)
|
logits = logits.view(-1, vocab_size)
|
||||||
@@ -76,12 +85,12 @@ def ForMaskedLMLoss(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
|
def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
|
||||||
num_labels = config.num_labels
|
num_labels = config.num_labels
|
||||||
if config.problem_type is None:
|
if config.problem_type is None:
|
||||||
if num_labels == 1:
|
if num_labels == 1:
|
||||||
config.problem_type = "regression"
|
config.problem_type = "regression"
|
||||||
elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
|
||||||
config.problem_type = "single_label_classification"
|
config.problem_type = "single_label_classification"
|
||||||
else:
|
else:
|
||||||
config.problem_type = "multi_label_classification"
|
config.problem_type = "multi_label_classification"
|
||||||
@@ -90,15 +99,17 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
|
|||||||
if config.problem_type == "regression":
|
if config.problem_type == "regression":
|
||||||
loss_fct = MSELoss()
|
loss_fct = MSELoss()
|
||||||
if num_labels == 1:
|
if num_labels == 1:
|
||||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
return loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||||
else:
|
else:
|
||||||
loss = loss_fct(pooled_logits, labels)
|
return loss_fct(pooled_logits, labels)
|
||||||
elif config.problem_type == "single_label_classification":
|
if config.problem_type == "single_label_classification":
|
||||||
loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
|
return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
|
||||||
elif config.problem_type == "multi_label_classification":
|
|
||||||
|
if config.problem_type == "multi_label_classification":
|
||||||
loss_fct = BCEWithLogitsLoss()
|
loss_fct = BCEWithLogitsLoss()
|
||||||
loss = loss_fct(pooled_logits, labels)
|
return loss_fct(pooled_logits, labels)
|
||||||
return loss
|
|
||||||
|
raise RuntimeError(f"Invalid problem type: {config.problem_type}")
|
||||||
|
|
||||||
|
|
||||||
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
|
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
|
||||||
@@ -120,7 +131,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
|
|||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
def ForTokenClassification(logits, labels, config, **kwargs):
|
def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
logits = logits.view(-1, config.num_labels)
|
logits = logits.view(-1, config.num_labels)
|
||||||
labels = labels.view(-1).to(logits.device)
|
labels = labels.view(-1).to(logits.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user