From 181d778f83bf6e58c1d69a7599afb2bb9ceff21e Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:21:56 +0200 Subject: [PATCH] [`NllbMoe`] Update code to properly support loss computation (#25429) * update nllb_moe * fix * doc nits * nits * add a small test * ficup * remove adapted from --- .../models/nllb_moe/modeling_nllb_moe.py | 23 +++++++++++-------- .../models/nllb_moe/test_modeling_nllb_moe.py | 10 ++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index a7c02cdeba..cf2bdd5e52 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l return incremental_indices.long() + padding_idx -# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. @@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T Returns: The auxiliary loss. """ + if router_probs is None: + return 0 + num_experts = router_probs.shape[-1] # cast the expert indices to int64, otherwise one-hot encoding will fail @@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module): if self.is_sparse: hidden_states, router_states = self.ffn(hidden_states, attention_mask) else: - hidden_states = self.ffn(hidden_states) + # router_states set to None to track which layers have None gradients. + hidden_states, router_states = self.ffn(hidden_states), None + hidden_states = self.ff_dropout(hidden_states) hidden_states = residual + hidden_states @@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module): if self.is_sparse: hidden_states, router_states = self.ffn(hidden_states, attention_mask) else: - hidden_states = self.ffn(hidden_states) + hidden_states, router_states = self.ffn(hidden_states), None + hidden_states = self.ff_dropout(hidden_states) hidden_states = residual + hidden_states @@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): if output_router_logits: encoder_router_logits = outputs[-1] - decoder_router_logits = outputs[5 if output_attentions else 3] + decoder_router_logits = outputs[3 if output_attentions else 4] # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits) @@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): decoder_router_logits=outputs.decoder_router_logits, ) - # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits def _unpack_router_logits(self, router_outputs): total_router_logits = [] total_expert_indexes = [] @@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): router_logits, expert_indexes = router_output total_router_logits.append(router_logits) total_expert_indexes.append(expert_indexes) - if len(total_expert_indexes) > 0: - total_router_logits = torch.cat(total_router_logits, dim=1) - if len(total_expert_indexes) > 0: - torch.cat(total_expert_indexes, dim=1) - return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) + + total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None + total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None + return total_router_logits, total_expert_indexes # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation def prepare_inputs_for_generation( diff --git a/tests/models/nllb_moe/test_modeling_nllb_moe.py b/tests/models/nllb_moe/test_modeling_nllb_moe.py index 9311a01990..409db2207e 100644 --- a/tests/models/nllb_moe/test_modeling_nllb_moe.py +++ b/tests/models/nllb_moe/test_modeling_nllb_moe.py @@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + def test_get_loss(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_dict["output_router_logits"] = True + input_dict["labels"] = input_dict["input_ids"] + model = NllbMoeForConditionalGeneration(config).eval().to(torch_device) + out = model(**input_dict) + self.assertIsNotNone(out.loss) + self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1]) + self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0]) + @require_torch @require_sentencepiece