tests: Fix flaky test for NLLB-MoE (#22880)
* add test update and docs edits * docs edit suggestion
This commit is contained in:
@@ -43,9 +43,10 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArtZucker).
|
|||||||
The original code can be found [here](https://github.com/facebookresearch/fairseq).
|
The original code can be found [here](https://github.com/facebookresearch/fairseq).
|
||||||
|
|
||||||
## Implementation differences with SwitchTransformers
|
## Implementation differences with SwitchTransformers
|
||||||
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that blah blah blah blah.
|
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
|
||||||
In SwitchTransformers, once the masks are computed for each experts, we just index the current hidden_states with the routing mask, and feed the
|
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
|
||||||
correct tokens to the expert. However here, the implementation varies a lot as the fairseq repository used a different approach.
|
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
|
||||||
|
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
|
||||||
|
|
||||||
## Generating with NLLB-MoE
|
## Generating with NLLB-MoE
|
||||||
The avalable checkpoints requires around 350GB of storage. Make sure to use `accelerate` if you do not have enough RAM on your machine.
|
The avalable checkpoints requires around 350GB of storage. Make sure to use `accelerate` if you do not have enough RAM on your machine.
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class NllbMoeConfig(PretrainedConfig):
|
|||||||
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||||
|
|||||||
@@ -460,7 +460,7 @@ class NllbMoeSparseMLP(nn.Module):
|
|||||||
Attention mask. Can be in the causal form or not.
|
Attention mask. Can be in the causal form or not.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
hidden_states (`torch.Tensor` of shape `(batch_size, sequence_lenght, hidden_dim)`):
|
hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
|
||||||
Updated hidden states
|
Updated hidden states
|
||||||
router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`):
|
router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`):
|
||||||
Needed for computing the loss
|
Needed for computing the loss
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import unittest
|
|||||||
|
|
||||||
from transformers import NllbMoeConfig, is_torch_available, set_seed
|
from transformers import NllbMoeConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -210,7 +209,7 @@ class NllbMoeModelTester:
|
|||||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
# test that outputs are equal for slice
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||||
model = NllbMoeModel(config=config).to(torch_device).eval()
|
model = NllbMoeModel(config=config).to(torch_device).eval()
|
||||||
@@ -290,10 +289,10 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||||
self.assertEqual(info["missing_keys"], [])
|
self.assertEqual(info["missing_keys"], [])
|
||||||
|
|
||||||
@is_flaky()
|
|
||||||
def test_decoder_model_past_with_large_inputs(self):
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
config.decoder_sparse_step = 0
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, inputs_dict)
|
||||||
|
|
||||||
def test_encoder_decoder_model_standalone(self):
|
def test_encoder_decoder_model_standalone(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user