Make more test models smaller (#25005)

* Make more test models tiny

* Make more test models tiny

* More models

* More models
This commit is contained in:
Sylvain Gugger
2023-07-24 10:08:47 -04:00
committed by GitHub
parent 8f1f0bf50f
commit 42571f6eb8
22 changed files with 149 additions and 137 deletions

View File

@@ -279,10 +279,6 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_resize_tokens_embeddings(self):
pass
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
def test_model_is_small(self):
pass
@require_torch
class EsmModelIntegrationTest(TestCasePlus):

View File

@@ -100,6 +100,28 @@ class EsmFoldModelTester:
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
esmfold_config = {
"trunk": {
"num_blocks": 2,
"sequence_state_dim": 64,
"pairwise_state_dim": 16,
"sequence_head_width": 4,
"pairwise_head_width": 4,
"position_bins": 4,
"chunk_size": 16,
"structure_module": {
"ipa_dim": 16,
"num_angles": 7,
"num_blocks": 2,
"num_heads_ipa": 4,
"pairwise_dim": 16,
"resnet_dim": 16,
"sequence_dim": 48,
},
},
"fp16_esm": False,
"lddt_head_hid_dim": 16,
}
config = EsmConfig(
vocab_size=33,
hidden_size=self.hidden_size,
@@ -114,7 +136,7 @@ class EsmFoldModelTester:
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
is_folding_model=True,
esmfold_config={"trunk": {"num_blocks": 2}, "fp16_esm": False},
esmfold_config=esmfold_config,
)
return config
@@ -126,8 +148,8 @@ class EsmFoldModelTester:
result = model(input_ids)
result = model(input_ids)
self.parent.assertEqual(result.positions.shape, (8, self.batch_size, self.seq_length, 14, 3))
self.parent.assertEqual(result.angles.shape, (8, self.batch_size, self.seq_length, 7, 2))
self.parent.assertEqual(result.positions.shape, (2, self.batch_size, self.seq_length, 14, 3))
self.parent.assertEqual(result.angles.shape, (2, self.batch_size, self.seq_length, 7, 2))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -243,10 +265,6 @@ class EsmFoldModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Will be fixed soon by reducing the size of the model used for common tests.")
def test_model_is_small(self):
pass
@require_torch
class EsmModelIntegrationTest(TestCasePlus):