[NllbMoe] Update code to properly support loss computation (#25429)
* update nllb_moe * fix * doc nits * nits * add a small test * fixup * remove adapted from
This commit is contained in:
@@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
|
|||||||
return incremental_indices.long() + padding_idx
|
return incremental_indices.long() + padding_idx
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel
|
|
||||||
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
|
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
|
||||||
r"""
|
r"""
|
||||||
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
||||||
@@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
|
|||||||
Returns:
|
Returns:
|
||||||
The auxiliary loss.
|
The auxiliary loss.
|
||||||
"""
|
"""
|
||||||
|
if router_probs is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
num_experts = router_probs.shape[-1]
|
num_experts = router_probs.shape[-1]
|
||||||
|
|
||||||
# cast the expert indices to int64, otherwise one-hot encoding will fail
|
# cast the expert indices to int64, otherwise one-hot encoding will fail
|
||||||
@@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module):
|
|||||||
if self.is_sparse:
|
if self.is_sparse:
|
||||||
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
|
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.ffn(hidden_states)
|
# router_states set to None to track which layers have None gradients.
|
||||||
|
hidden_states, router_states = self.ffn(hidden_states), None
|
||||||
|
|
||||||
hidden_states = self.ff_dropout(hidden_states)
|
hidden_states = self.ff_dropout(hidden_states)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module):
|
|||||||
if self.is_sparse:
|
if self.is_sparse:
|
||||||
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
|
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.ffn(hidden_states)
|
hidden_states, router_states = self.ffn(hidden_states), None
|
||||||
|
|
||||||
hidden_states = self.ff_dropout(hidden_states)
|
hidden_states = self.ff_dropout(hidden_states)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
|
|||||||
|
|
||||||
if output_router_logits:
|
if output_router_logits:
|
||||||
encoder_router_logits = outputs[-1]
|
encoder_router_logits = outputs[-1]
|
||||||
decoder_router_logits = outputs[5 if output_attentions else 3]
|
decoder_router_logits = outputs[3 if output_attentions else 4]
|
||||||
|
|
||||||
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
|
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
|
||||||
encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
|
encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
|
||||||
@@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
|
|||||||
decoder_router_logits=outputs.decoder_router_logits,
|
decoder_router_logits=outputs.decoder_router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits
|
|
||||||
def _unpack_router_logits(self, router_outputs):
|
def _unpack_router_logits(self, router_outputs):
|
||||||
total_router_logits = []
|
total_router_logits = []
|
||||||
total_expert_indexes = []
|
total_expert_indexes = []
|
||||||
@@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
|
|||||||
router_logits, expert_indexes = router_output
|
router_logits, expert_indexes = router_output
|
||||||
total_router_logits.append(router_logits)
|
total_router_logits.append(router_logits)
|
||||||
total_expert_indexes.append(expert_indexes)
|
total_expert_indexes.append(expert_indexes)
|
||||||
if len(total_expert_indexes) > 0:
|
|
||||||
total_router_logits = torch.cat(total_router_logits, dim=1)
|
total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
|
||||||
if len(total_expert_indexes) > 0:
|
total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
|
||||||
torch.cat(total_expert_indexes, dim=1)
|
return total_router_logits, total_expert_indexes
|
||||||
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
|
|
||||||
|
|
||||||
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
|
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
|
|||||||
@@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
model.generate(input_ids, attention_mask=attention_mask)
|
model.generate(input_ids, attention_mask=attention_mask)
|
||||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
|
||||||
|
def test_get_loss(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
input_dict["output_router_logits"] = True
|
||||||
|
input_dict["labels"] = input_dict["input_ids"]
|
||||||
|
model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
|
||||||
|
out = model(**input_dict)
|
||||||
|
self.assertIsNotNone(out.loss)
|
||||||
|
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
|
||||||
|
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
Reference in New Issue
Block a user