update example to work with new serialization semantic
This commit is contained in:
@@ -39,7 +39,7 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfi
|
||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
whitespace_tokenize)
|
||||
whitespace_tokenize, VOCAB_NAME)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
@@ -1009,18 +1009,21 @@ def main():
|
||||
global_step += 1
|
||||
|
||||
if args.do_train:
|
||||
# Save a trained model and the associated configuration
|
||||
# Save a trained model, configuration and tokenizer
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
with open(output_config_file, 'w') as f:
|
||||
f.write(model_to_save.config.to_json_string())
|
||||
|
||||
# Load a trained model and config that you have fine-tuned
|
||||
config = BertConfig(output_config_file)
|
||||
model = BertForQuestionAnswering(config)
|
||||
model.load_state_dict(torch.load(output_model_file))
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_vocab_file)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForQuestionAnswering.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
else:
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user