MixtralSparseMoeBlock: add gate jitter (#29865)
This commit adds gate jitter to MixtralSparseMoeBlock's input data before passing it through the MoE layer, if turned on.
This commit is contained in:
@@ -42,7 +42,6 @@ if is_torch_available():
|
||||
|
||||
|
||||
class MixtralModelTester:
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
@@ -69,6 +68,7 @@ class MixtralModelTester:
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
router_jitter_noise=0.1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -94,6 +94,7 @@ class MixtralModelTester:
|
||||
self.num_choices = num_choices
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
self.router_jitter_noise = router_jitter_noise
|
||||
|
||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
|
||||
def prepare_config_and_inputs(self):
|
||||
@@ -137,6 +138,7 @@ class MixtralModelTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
num_experts_per_tok=2,
|
||||
num_local_experts=2,
|
||||
router_jitter_noise=self.router_jitter_noise,
|
||||
)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
|
||||
|
||||
Reference in New Issue
Block a user