GPT Neo few fixes (#10968)

* fix checkpoint names

* auto model

* fix doc
This commit is contained in:
Suraj Patil
2021-03-30 20:45:55 +05:30
committed by GitHub
parent 7772ddb473
commit 83d38c9ff3
7 changed files with 17 additions and 15 deletions

View File

@@ -432,7 +432,7 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
@slow
def test_batch_generation(self):
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
model.to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
@@ -486,7 +486,7 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt_neo(self):
for checkpointing in [True, False]:
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl", gradient_checkpointing=checkpointing)
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", gradient_checkpointing=checkpointing)
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# fmt: off
@@ -497,8 +497,8 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_gpt_neo_sample(self):
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt_neo_xl")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt_neo_xl")
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
model.to(torch_device)
torch.manual_seed(0)