spliting config and weight files for bert also
This commit is contained in:
19
README.md
19
README.md
@@ -1432,6 +1432,25 @@ The results were similar to the above FP32 results (actually slightly higher):
|
|||||||
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
|
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Here is an example with the recent `bert-large-uncased-whole-word-masking`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m torch.distributed.launch --nproc_per_node=8 \
|
||||||
|
run_squad.py \
|
||||||
|
--bert_model bert-large-uncased-whole-word-masking \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--do_lower_case \
|
||||||
|
--train_file $SQUAD_DIR/train-v1.1.json \
|
||||||
|
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
||||||
|
--train_batch_size 12 \
|
||||||
|
--learning_rate 3e-5 \
|
||||||
|
--num_train_epochs 2.0 \
|
||||||
|
--max_seq_length 384 \
|
||||||
|
--doc_stride 128 \
|
||||||
|
--output_dir /tmp/debug_squad/
|
||||||
|
```
|
||||||
|
|
||||||
## Notebooks
|
## Notebooks
|
||||||
|
|
||||||
We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
|
We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
|
||||||
|
|||||||
92
examples/bertology.py
Normal file
92
examples/bertology.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert import BertModel, BertTokenizer
|
||||||
|
|
||||||
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
|
level = logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def run_model():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased',
|
||||||
|
help='pretrained model name or path to local checkpoint')
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
parser.add_argument("--batch_size", type=int, default=-1)
|
||||||
|
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
if args.batch_size == -1:
|
||||||
|
args.batch_size = 1
|
||||||
|
assert args.nsamples % args.batch_size == 0
|
||||||
|
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.random.manual_seed(args.seed)
|
||||||
|
torch.cuda.manual_seed(args.seed)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if args.length == -1:
|
||||||
|
args.length = model.config.n_ctx // 2
|
||||||
|
elif args.length > model.config.n_ctx:
|
||||||
|
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
context_tokens = []
|
||||||
|
if not args.unconditional:
|
||||||
|
raw_text = input("Model prompt >>> ")
|
||||||
|
while not raw_text:
|
||||||
|
print('Prompt should not be empty!')
|
||||||
|
raw_text = input("Model prompt >>> ")
|
||||||
|
context_tokens = enc.encode(raw_text)
|
||||||
|
generated = 0
|
||||||
|
for _ in range(args.nsamples // args.batch_size):
|
||||||
|
out = sample_sequence(
|
||||||
|
model=model, length=args.length,
|
||||||
|
context=context_tokens,
|
||||||
|
start_token=None,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
temperature=args.temperature, top_k=args.top_k, device=device
|
||||||
|
)
|
||||||
|
out = out[:, len(context_tokens):].tolist()
|
||||||
|
for i in range(args.batch_size):
|
||||||
|
generated += 1
|
||||||
|
text = enc.decode(out[i])
|
||||||
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
|
print(text)
|
||||||
|
print("=" * 80)
|
||||||
|
else:
|
||||||
|
generated = 0
|
||||||
|
for _ in range(args.nsamples // args.batch_size):
|
||||||
|
out = sample_sequence(
|
||||||
|
model=model, length=args.length,
|
||||||
|
context=None,
|
||||||
|
start_token=enc.encoder['<|endoftext|>'],
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
temperature=args.temperature, top_k=args.top_k, device=device
|
||||||
|
)
|
||||||
|
out = out[:,1:].tolist()
|
||||||
|
for i in range(args.batch_size):
|
||||||
|
generated += 1
|
||||||
|
text = enc.decode(out[i])
|
||||||
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
|
print(text)
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_model()
|
||||||
|
|
||||||
|
|
||||||
@@ -22,9 +22,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
@@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
|
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
||||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
|
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
||||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
|
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
||||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
|
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
||||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
|
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
||||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz",
|
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
||||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz",
|
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz",
|
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
||||||
|
}
|
||||||
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||||
|
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
||||||
|
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
||||||
|
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
||||||
|
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
||||||
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
||||||
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
||||||
|
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
||||||
|
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
||||||
|
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
||||||
}
|
}
|
||||||
BERT_CONFIG_NAME = 'bert_config.json'
|
BERT_CONFIG_NAME = 'bert_config.json'
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
@@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||||
archive_file))
|
archive_file))
|
||||||
return None
|
return None
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file and resolved_config_file == config_file:
|
||||||
logger.info("loading archive file {}".format(archive_file))
|
logger.info("loading weights file {}".format(archive_file))
|
||||||
|
logger.info("loading configuration file {}".format(config_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading archive file {} from cache at {}".format(
|
logger.info("loading weights file {} from cache at {}".format(
|
||||||
archive_file, resolved_archive_file))
|
archive_file, resolved_archive_file))
|
||||||
tempdir = None
|
logger.info("loading configuration file {} from cache at {}".format(
|
||||||
if os.path.isdir(resolved_archive_file) or from_tf:
|
config_file, resolved_config_file))
|
||||||
serialization_dir = resolved_archive_file
|
### Switching to split config/weight files configuration
|
||||||
else:
|
# tempdir = None
|
||||||
# Extract archive to temp dir
|
# if os.path.isdir(resolved_archive_file) or from_tf:
|
||||||
tempdir = tempfile.mkdtemp()
|
# serialization_dir = resolved_archive_file
|
||||||
logger.info("extracting archive file {} to temp dir {}".format(
|
# else:
|
||||||
resolved_archive_file, tempdir))
|
# # Extract archive to temp dir
|
||||||
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
# tempdir = tempfile.mkdtemp()
|
||||||
archive.extractall(tempdir)
|
# logger.info("extracting archive file {} to temp dir {}".format(
|
||||||
serialization_dir = tempdir
|
# resolved_archive_file, tempdir))
|
||||||
|
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||||
|
# archive.extractall(tempdir)
|
||||||
|
# serialization_dir = tempdir
|
||||||
# Load config
|
# Load config
|
||||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||||
if not os.path.exists(config_file):
|
if not os.path.exists(config_file):
|
||||||
@@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path, map_location='cpu')
|
state_dict = torch.load(weights_path, map_location='cpu')
|
||||||
if tempdir:
|
# if tempdir:
|
||||||
# Clean up temp dir
|
# # Clean up temp dir
|
||||||
shutil.rmtree(tempdir)
|
# shutil.rmtree(tempdir)
|
||||||
if from_tf:
|
if from_tf:
|
||||||
# Directly load from a TensorFlow checkpoint
|
# Directly load from a TensorFlow checkpoint
|
||||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
import shutil
|
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|||||||
Reference in New Issue
Block a user