Fix get large model config for Switch Transformer encoder only tester (#32438)

This commit is contained in:
Francisco Kurucz
2024-08-06 07:48:32 -03:00
committed by GitHub
parent fb66ef8147
commit 438d06c95a

View File

@@ -770,7 +770,7 @@ class SwitchTransformersEncoderOnlyModelTester:
self.is_training = is_training self.is_training = is_training
def get_large_model_config(self): def get_large_model_config(self):
return SwitchTransformersConfig.from_pretrained("switch_base_8") return SwitchTransformersConfig.from_pretrained("google/switch-base-8")
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)