update run_openai_gpt to fix #1264
This commit is contained in:
@@ -153,9 +153,11 @@ def main():
|
|||||||
# This loading functions also add new tokens and embeddings called `special tokens`
|
# This loading functions also add new tokens and embeddings called `special tokens`
|
||||||
# These new embeddings will be fine-tuned on the RocStories dataset
|
# These new embeddings will be fine-tuned on the RocStories dataset
|
||||||
special_tokens = ['_start_', '_delimiter_', '_classify_']
|
special_tokens = ['_start_', '_delimiter_', '_classify_']
|
||||||
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens)
|
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
|
||||||
special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
|
tokenizer.add_tokens(special_tokens)
|
||||||
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens))
|
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
|
||||||
|
model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name)
|
||||||
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
# Load and encode the datasets
|
# Load and encode the datasets
|
||||||
|
|||||||
Reference in New Issue
Block a user