clarify and unify model saving logic in examples

This commit is contained in:
thomwolf
2019-02-11 14:04:19 +01:00
parent 81c7e3ec9f
commit eebc8abbe2
4 changed files with 59 additions and 28 deletions

View File

@@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer,
@@ -1001,14 +1001,19 @@ def main():
optimizer.zero_grad()
global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train:
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
with open(output_config_file, 'w') as f:
f.write(model_to_save.config.to_json_string())
# Load a trained model and config that you have fine-tuned
config = BertConfig(output_config_file)
model = BertForQuestionAnswering(config)
model.load_state_dict(torch.load(output_model_file))
else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model)