unified tokenizer api and serialization + tests

This commit is contained in:
thomwolf
2019-07-09 10:25:18 +02:00
parent 3d5f291386
commit b19786985d
34 changed files with 824 additions and 755 deletions

View File

@@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
@@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
@@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `BertModel`.
"""
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
@@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
a simple interface for dowloading and loading pretrained models.
"""
config_class = BertConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"