Merge pull request #489 from huggingface/tokenization_serialization
Better serialization for Tokenizers and Configuration classes - Also fix #466
This commit is contained in:
@@ -34,7 +34,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
@@ -225,6 +223,11 @@ class OpenAIGPTConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path):
|
||||
""" Save this instance to a json file."""
|
||||
with open(json_file_path, "w", encoding='utf-8') as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, rf, nx):
|
||||
@@ -473,7 +476,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
|
||||
|
||||
Reference in New Issue
Block a user