From eebc8abbe2e563fc334fd1dadfd31819fd1286b6 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 11 Feb 2019 14:04:19 +0100 Subject: [PATCH] clarify and unify model saving logic in examples --- README.md | 3 ++- examples/run_classifier.py | 40 +++++++++++++++++++++++++++----------- examples/run_squad.py | 19 +++++++++++------- examples/run_swag.py | 25 +++++++++++++++--------- 4 files changed, 59 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 4bce27db47..8efd64e5d2 100644 --- a/README.md +++ b/README.md @@ -779,7 +779,8 @@ python run_classifier.py \ --train_batch_size 32 \ --learning_rate 2e-5 \ --num_train_epochs 3.0 \ - --output_dir /tmp/mrpc_output/ + --output_dir /tmp/mrpc_output/ \ + --fp16 ``` #### SQuAD diff --git a/examples/run_classifier.py b/examples/run_classifier.py index a30d7982b0..83f0683a48 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -23,7 +23,6 @@ import logging import os import random import sys -from io import open import numpy as np import torch @@ -33,7 +32,7 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -from pytorch_pretrained_bert.modeling import BertForSequenceClassification +from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear @@ -92,7 +91,7 @@ class DataProcessor(object): @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" - with open(input_file, "rb") as f: + with open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: @@ -324,6 +323,10 @@ def main(): help="The output directory where the model predictions and checkpoints will be written.") ## Other parameters + parser.add_argument("--cache_dir", + default="", + type=str, + help="Where do you want to store the pre-trained models downloaded from s3") parser.add_argument("--max_seq_length", default=128, type=int, @@ -383,9 +386,17 @@ def main(): help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" "0 (default value): dynamic loss scaling.\n" "Positive power of 2: static loss scaling value.\n") - + parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") + parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") args = parser.parse_args() + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + processors = { "cola": ColaProcessor, "mnli": MnliProcessor, @@ -451,8 +462,9 @@ def main(): num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() # Prepare model + cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)) model = BertForSequenceClassification.from_pretrained(args.bert_model, - cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)), + cache_dir=cache_dir, num_labels = num_labels) if args.fp16: model.half() @@ -549,15 +561,21 @@ def main(): optimizer.zero_grad() global_step += 1 - # Save a trained model - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") if args.do_train: + # Save a trained model and the associated configuration + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) torch.save(model_to_save.state_dict(), output_model_file) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + with open(output_config_file, 'w') as f: + f.write(model_to_save.config.to_json_string()) - # Load a trained model that you have fine-tuned - model_state_dict = torch.load(output_model_file) - model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels) + # Load a trained model and config that you have fine-tuned + config = BertConfig(output_config_file) + model = BertForSequenceClassification(config, num_labels=num_labels) + model.load_state_dict(torch.load(output_model_file)) + else: + model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) model.to(device) if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): diff --git a/examples/run_squad.py b/examples/run_squad.py index 9c2035701d..1d7c49c326 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -35,7 +35,7 @@ from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -from pytorch_pretrained_bert.modeling import BertForQuestionAnswering +from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.tokenization import (BasicTokenizer, BertTokenizer, @@ -1001,14 +1001,19 @@ def main(): optimizer.zero_grad() global_step += 1 - # Save a trained model - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") if args.do_train: + # Save a trained model and the associated configuration + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) torch.save(model_to_save.state_dict(), output_model_file) - # Load a trained model that you have fine-tuned - model_state_dict = torch.load(output_model_file) - model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + with open(output_config_file, 'w') as f: + f.write(model_to_save.config.to_json_string()) + + # Load a trained model and config that you have fine-tuned + config = BertConfig(output_config_file) + model = BertForQuestionAnswering(config) + model.load_state_dict(torch.load(output_model_file)) else: model = BertForQuestionAnswering.from_pretrained(args.bert_model) diff --git a/examples/run_swag.py b/examples/run_swag.py index 52bcdcbd31..3ecea63046 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -469,18 +469,25 @@ def main(): optimizer.zero_grad() global_step += 1 - # Save a trained model - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") - torch.save(model_to_save.state_dict(), output_model_file) - # Load a trained model that you have fine-tuned - model_state_dict = torch.load(output_model_file) - model = BertForMultipleChoice.from_pretrained(args.bert_model, - state_dict=model_state_dict, - num_choices=4) + if args.do_train: + # Save a trained model and the associated configuration + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + torch.save(model_to_save.state_dict(), output_model_file) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + with open(output_config_file, 'w') as f: + f.write(model_to_save.config.to_json_string()) + + # Load a trained model and config that you have fine-tuned + config = BertConfig(output_config_file) + model = BertForMultipleChoice(config, num_choices=4) + model.load_state_dict(torch.load(output_model_file)) + else: + model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4) model.to(device) + if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True) eval_features = convert_examples_to_features(