added support for gradient checkpointing in ESM models (#26386)

This commit is contained in:
sanjeevk-os
2023-09-26 18:15:53 +10:00
committed by GitHub
parent a8531f3bfd
commit 6ce6a5adb9
2 changed files with 27 additions and 6 deletions

View File

@@ -151,6 +151,24 @@ class EsmModelTester:
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_forward_and_backwards(
    self,
    config,
    input_ids,
    input_mask,
    sequence_labels,
    token_labels,
    choice_labels,
    gradient_checkpointing=False,
):
    """Run one forward and one backward pass through `EsmForMaskedLM`.

    A model is built from `config`, optionally switched to gradient
    checkpointing, moved to `torch_device`, and called with `token_labels`
    as labels. The logits shape is checked against
    `(batch_size, seq_length, vocab_size)` before the loss is
    back-propagated.

    Args:
        config: Configuration used to instantiate the model.
        input_ids: Token ids passed as the model's first positional input.
        input_mask: Passed to the model as `attention_mask`.
        sequence_labels: Not used by this check; accepted so the prepared
            inputs tuple can be unpacked positionally.
        token_labels: Passed to the model as `labels`.
        choice_labels: Not used by this check; accepted for the same reason
            as `sequence_labels`.
        gradient_checkpointing: When truthy, `gradient_checkpointing_enable()`
            is called on the model before it is moved to the device.
    """
    masked_lm = EsmForMaskedLM(config)
    if gradient_checkpointing:
        masked_lm.gradient_checkpointing_enable()
    masked_lm.to(torch_device)

    outputs = masked_lm(input_ids, attention_mask=input_mask, labels=token_labels)

    actual_logits_shape = outputs.logits.shape
    expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
    self.parent.assertEqual(actual_logits_shape, expected_logits_shape)

    # The backward pass is the point of this check: it must complete without
    # raising, including when gradient checkpointing has been enabled.
    outputs.loss.backward()
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_esm_gradient_checkpointing(self):
    # Run the tester's forward/backward check with gradient checkpointing
    # switched on; the prepared tuple is unpacked positionally into it.
    prepared_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_forward_and_backwards(
        *prepared_inputs,
        gradient_checkpointing=True,
    )
@slow
def test_model_from_pretrained(self):
for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: