Make more test models smaller (#25005)
* Make more test models tiny * Make more test models tiny * More models * More models
This commit is contained in:
@@ -100,6 +100,28 @@ class EsmFoldModelTester:
|
||||
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
esmfold_config = {
|
||||
"trunk": {
|
||||
"num_blocks": 2,
|
||||
"sequence_state_dim": 64,
|
||||
"pairwise_state_dim": 16,
|
||||
"sequence_head_width": 4,
|
||||
"pairwise_head_width": 4,
|
||||
"position_bins": 4,
|
||||
"chunk_size": 16,
|
||||
"structure_module": {
|
||||
"ipa_dim": 16,
|
||||
"num_angles": 7,
|
||||
"num_blocks": 2,
|
||||
"num_heads_ipa": 4,
|
||||
"pairwise_dim": 16,
|
||||
"resnet_dim": 16,
|
||||
"sequence_dim": 48,
|
||||
},
|
||||
},
|
||||
"fp16_esm": False,
|
||||
"lddt_head_hid_dim": 16,
|
||||
}
|
||||
config = EsmConfig(
|
||||
vocab_size=33,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -114,7 +136,7 @@ class EsmFoldModelTester:
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
is_folding_model=True,
|
||||
esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False},
|
||||
esmfold_config=esmfold_config,
|
||||
)
|
||||
return config
|
||||
|
||||
@@ -126,8 +148,8 @@ class EsmFoldModelTester:
|
||||
result = model(input_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3))
|
||||
self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2))
|
||||
self.parent.assertEqual(result.positions.shape, (2, self.batch_size, self.seq_length, 14, 3))
|
||||
self.parent.assertEqual(result.angles.shape, (2, self.batch_size, self.seq_length, 7, 2))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
@@ -243,10 +265,6 @@ class EsmFoldModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class EsmModelIntegrationTest(TestCasePlus):
|
||||
|
||||
Reference in New Issue
Block a user