update example to work with new serialization semantic

This commit is contained in:
thomwolf
2019-04-15 14:33:23 +02:00
parent b3c6ee0ac1
commit 179a2c2ff6
4 changed files with 55 additions and 40 deletions

View File

@@ -37,7 +37,7 @@ from sklearn.metrics import matthews_corrcoef, f1_score
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -857,18 +857,21 @@ def main():
optimizer.zero_grad()
global_step += 1
# Save a trained model and the associated configuration
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
with open(output_config_file, 'w') as f:
f.write(model_to_save.config.to_json_string())
# Load a trained model and config that you have fine-tuned
config = BertConfig(output_config_file)
model = BertForSequenceClassification(config, num_labels=num_labels)
model.load_state_dict(torch.load(output_model_file))
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)
# Load a trained model and vocabulary that you have fine-tuned
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
model.to(device)