MixtralSparseMoeBlock: add gate jitter (#29865)
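This commit adds optional gate jitter to MixtralSparseMoeBlock, controlled by a new `router_jitter_noise` config option (default 0.0, i.e. disabled). When the value is greater than 0 and the module is in training mode, the block's input hidden states are multiplied element-wise by noise drawn uniformly from [1 - router_jitter_noise, 1 + router_jitter_noise] before being passed through the router and the MoE layer.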
This commit is contained in:
@@ -92,6 +92,8 @@ class MixtralConfig(PretrainedConfig):
|
|||||||
allow the model to output the auxiliary loss. See [here]() for more details
|
allow the model to output the auxiliary loss. See [here]() for more details
|
||||||
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||||
The aux loss factor for the total loss.
|
The aux loss factor for the total loss.
|
||||||
|
router_jitter_noise (`float`, *optional*, defaults to 0.0):
|
||||||
|
Amount of noise to add to the router.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import MixtralModel, MixtralConfig
|
>>> from transformers import MixtralModel, MixtralConfig
|
||||||
@@ -133,6 +135,7 @@ class MixtralConfig(PretrainedConfig):
|
|||||||
num_local_experts=8,
|
num_local_experts=8,
|
||||||
output_router_logits=False,
|
output_router_logits=False,
|
||||||
router_aux_loss_coef=0.001,
|
router_aux_loss_coef=0.001,
|
||||||
|
router_jitter_noise=0.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -159,6 +162,7 @@ class MixtralConfig(PretrainedConfig):
|
|||||||
self.num_local_experts = num_local_experts
|
self.num_local_experts = num_local_experts
|
||||||
self.output_router_logits = output_router_logits
|
self.output_router_logits = output_router_logits
|
||||||
self.router_aux_loss_coef = router_aux_loss_coef
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
self.router_jitter_noise = router_jitter_noise
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
|
|||||||
@@ -837,9 +837,14 @@ class MixtralSparseMoeBlock(nn.Module):
|
|||||||
|
|
||||||
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
||||||
|
|
||||||
|
# Jitter parameters
|
||||||
|
self.jitter_noise = config.router_jitter_noise
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
""" """
|
""" """
|
||||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
if self.training and self.jitter_noise > 0:
|
||||||
|
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
# router_logits: (batch * sequence_length, n_experts)
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
class MixtralModelTester:
|
class MixtralModelTester:
|
||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
@@ -69,6 +68,7 @@ class MixtralModelTester:
|
|||||||
num_choices=4,
|
num_choices=4,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
router_jitter_noise=0.1,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -94,6 +94,7 @@ class MixtralModelTester:
|
|||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.router_jitter_noise = router_jitter_noise
|
||||||
|
|
||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
|
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -137,6 +138,7 @@ class MixtralModelTester:
|
|||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
num_experts_per_tok=2,
|
num_experts_per_tok=2,
|
||||||
num_local_experts=2,
|
num_local_experts=2,
|
||||||
|
router_jitter_noise=self.router_jitter_noise,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
|
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
|
||||||
|
|||||||
Reference in New Issue
Block a user