unified tokenizer api and serialization + tests
This commit is contained in:
@@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
|
||||
@@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||
class OpenAIGPTConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||
"""
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
config_class = OpenAIGPTConfig
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_openai_gpt
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user