From 382e2d1e50e5ea95c4393b5758e0b0907f43e1c5 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 10:37:16 +0200 Subject: [PATCH] spliting config and weight files for bert also --- README.md | 19 ++++ examples/bertology.py | 92 +++++++++++++++++++ pytorch_pretrained_bert/modeling.py | 78 +++++++++------- pytorch_pretrained_bert/modeling_gpt2.py | 3 - pytorch_pretrained_bert/modeling_openai.py | 3 - .../modeling_transfo_xl.py | 3 - 6 files changed, 158 insertions(+), 40 deletions(-) create mode 100644 examples/bertology.py diff --git a/README.md b/README.md index 9ddeba808e..a596dc4c64 100644 --- a/README.md +++ b/README.md @@ -1432,6 +1432,25 @@ The results were similar to the above FP32 results (actually slightly higher): {"exact_match": 84.65468306527909, "f1": 91.238669287002} ``` +Here is an example with the recent `bert-large-uncased-whole-word-masking`: + +```bash +python -m torch.distributed.launch --nproc_per_node=8 \ + run_squad.py \ + --bert_model bert-large-uncased-whole-word-masking \ + --do_train \ + --do_predict \ + --do_lower_case \ + --train_file $SQUAD_DIR/train-v1.1.json \ + --predict_file $SQUAD_DIR/dev-v1.1.json \ + --train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 2.0 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir /tmp/debug_squad/ +``` + ## Notebooks We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model. diff --git a/examples/bertology.py b/examples/bertology.py new file mode 100644 index 0000000000..0ceac28ff9 --- /dev/null +++ b/examples/bertology.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +import argparse +import logging +from tqdm import trange + +import torch +import torch.nn.functional as F +import numpy as np + +from pytorch_pretrained_bert import BertModel, BertTokenizer + +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + +def run_model(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased', + help='pretrained model name or path to local checkpoint') + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--batch_size", type=int, default=-1) + parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') + args = parser.parse_args() + print(args) + + if args.batch_size == -1: + args.batch_size = 1 + assert args.nsamples % args.batch_size == 0 + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) + model.to(device) + model.eval() + + if args.length == -1: + args.length = model.config.n_ctx // 2 + elif args.length > model.config.n_ctx: + raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) + + while True: + context_tokens = [] + if not args.unconditional: + raw_text = input("Model prompt >>> ") + while not raw_text: + print('Prompt should not be empty!') + raw_text = input("Model prompt >>> ") + context_tokens = enc.encode(raw_text) + generated = 0 + for _ in range(args.nsamples // args.batch_size): + out = sample_sequence( + model=model, length=args.length, + context=context_tokens, + start_token=None, + batch_size=args.batch_size, + temperature=args.temperature, top_k=args.top_k, device=device + ) + out = out[:, len(context_tokens):].tolist() + for i in range(args.batch_size): + generated += 1 + text = enc.decode(out[i]) + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(text) + print("=" * 80) + else: + generated = 0 + for _ in range(args.nsamples // args.batch_size): + out = sample_sequence( + model=model, length=args.length, + context=None, + start_token=enc.encoder['<|endoftext|>'], + batch_size=args.batch_size, + temperature=args.temperature, top_k=args.top_k, device=device + ) + out = out[:,1:].tolist() + for i in range(args.batch_size): + generated += 1 + text = enc.decode(out[i]) + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(text) + print("=" * 80) + +if __name__ == '__main__': + run_model() + + diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index fa2cba9aa7..006e6a1c73 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -22,9 +22,6 @@ import json import logging import math import os -import shutil -import tarfile -import tempfile import sys from io import open @@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", - 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz", + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", + 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", +} +PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", + 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", } BERT_CONFIG_NAME = 'bert_config.json' TF_WEIGHTS_NAME = 'model.ckpt' @@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module): if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] else: - archive_file = pretrained_model_name_or_path + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: logger.error( @@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module): ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file)) return None - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) + if resolved_archive_file == archive_file and resolved_config_file == config_file: + logger.info("loading weights file {}".format(archive_file)) + logger.info("loading configuration file {}".format(config_file)) else: - logger.info("loading archive file {} from cache at {}".format( + logger.info("loading weights file {} from cache at {}".format( archive_file, resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: - archive.extractall(tempdir) - serialization_dir = tempdir + logger.info("loading configuration file {} from cache at {}".format( + config_file, resolved_config_file)) + ### Switching to split config/weight files configuration + # tempdir = None + # if os.path.isdir(resolved_archive_file) or from_tf: + # serialization_dir = resolved_archive_file + # else: + # # Extract archive to temp dir + # tempdir = tempfile.mkdtemp() + # logger.info("extracting archive file {} to temp dir {}".format( + # resolved_archive_file, tempdir)) + # with tarfile.open(resolved_archive_file, 'r:gz') as archive: + # archive.extractall(tempdir) + # serialization_dir = tempdir # Load config config_file = os.path.join(serialization_dir, CONFIG_NAME) if not os.path.exists(config_file): @@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module): if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load(weights_path, map_location='cpu') - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) + # if tempdir: + # # Clean up temp dir + # shutil.rmtree(tempdir) if from_tf: # Directly load from a TensorFlow checkpoint weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index e3ae47402e..caa9cf809c 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -23,9 +23,6 @@ import json import logging import math import os -import shutil -import tarfile -import tempfile import sys from io import open diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 257c2f985c..d525c96e77 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -23,9 +23,6 @@ import json import logging import math import os -import shutil -import tarfile -import tempfile import sys from io import open diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 12a1535ff6..b3e829670a 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -25,9 +25,6 @@ import copy import json import math import logging -import tarfile -import tempfile -import shutil import collections import sys from io import open