Adding gradient checkpointing to GPT2 (#7446)
* GPT2 gradient checkpointing * find_unused_parameters removed if checkpointing * find_unused_parameters removed if checkpointing * Update src/transformers/configuration_gpt2.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Added a test for generation with checkpointing * Update src/transformers/configuration_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -88,7 +88,7 @@ class GPT2ModelTester:
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
@@ -127,6 +127,7 @@ class GPT2ModelTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@@ -269,6 +270,15 @@ class GPT2ModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPT2LMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_double_lm_head_model(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
):
|
||||
@@ -355,6 +365,10 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_gpt2_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
@@ -366,33 +380,34 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
for checkpointing in [True, False]:
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
|
||||
Reference in New Issue
Block a user