fix typos/bugs
This commit is contained in:
@@ -130,7 +130,7 @@ def gpt2LMHeadModel(*args, **kwargs):
|
|||||||
>>> predicted_token = tokenizer.decode([predicted_index])
|
>>> predicted_token = tokenizer.decode([predicted_index])
|
||||||
>>> assert predicted_token == ' who'
|
>>> assert predicted_token == ' who'
|
||||||
"""
|
"""
|
||||||
model = OpenAIGPTLMHeadModel.from_pretrained(*args, **kwargs)
|
model = GPT2LMHeadModel.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -148,9 +148,9 @@ def gpt2DoubleHeadsModel(*args, **kwargs):
|
|||||||
|
|
||||||
# Prepare tokenized input
|
# Prepare tokenized input
|
||||||
>>> text = "Who was Jim Henson ?"
|
>>> text = "Who was Jim Henson ?"
|
||||||
>>> indexed_tokens = tokenizer.encode(tokenized_text)
|
>>> indexed_tokens = tokenizer.encode(text)
|
||||||
>>> tokens_tensor = torch.tensor([indexed_tokens])
|
>>> tokens_tensor = torch.tensor([indexed_tokens])
|
||||||
>>> mc_token_ids = torch.LongTensor([ [len(tokenized_text)] ])
|
>>> mc_token_ids = torch.LongTensor([ [len(indexed_tokens)] ])
|
||||||
|
|
||||||
# Load gpt2DoubleHeadsModel
|
# Load gpt2DoubleHeadsModel
|
||||||
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2DoubleHeadsModel', 'gpt2')
|
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2DoubleHeadsModel', 'gpt2')
|
||||||
|
|||||||
Reference in New Issue
Block a user