Split saved model into config & weights
This commit is contained in:
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"}
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||||
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-openai_gpt_config.json"}
|
||||||
|
|
||||||
CONFIG_NAME = "openai_gpt_config.json"
|
CONFIG_NAME = "openai_gpt_config.json"
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
|
||||||
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||||
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name
|
archive_file = pretrained_model_name
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
"associated to this path or url.".format(
|
"at this path or url.".format(
|
||||||
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file
|
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||||
|
archive_file, config_file
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file and resolved_config_file == config_file:
|
||||||
logger.info("loading archive file {}".format(archive_file))
|
logger.info("loading weights file {}".format(archive_file))
|
||||||
|
logger.info("loading configuration file {}".format(config_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file))
|
logger.info("loading weights file {} from cache at {}".format(
|
||||||
tempdir = None
|
archive_file, resolved_archive_file))
|
||||||
if os.path.isdir(resolved_archive_file):
|
logger.info("loading configuration file {} from cache at {}".format(
|
||||||
serialization_dir = resolved_archive_file
|
config_file, resolved_config_file))
|
||||||
else:
|
|
||||||
# Extract archive to temp dir
|
|
||||||
tempdir = tempfile.mkdtemp()
|
|
||||||
logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
|
|
||||||
with tarfile.open(resolved_archive_file, "r:gz") as archive:
|
|
||||||
archive.extractall(tempdir)
|
|
||||||
serialization_dir = tempdir
|
|
||||||
# Load config
|
# Load config
|
||||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
config = OpenAIGPTConfig.from_json_file(resolved_config_file)
|
||||||
config = OpenAIGPTConfig.from_json_file(config_file)
|
|
||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||||
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
|
||||||
if tempdir:
|
|
||||||
# Clean up temp dir
|
|
||||||
shutil.rmtree(tempdir)
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||||
return load_tf_weights_in_openai_gpt(model, serialization_dir)
|
return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
|
||||||
|
|
||||||
old_keys = []
|
old_keys = []
|
||||||
new_keys = []
|
new_keys = []
|
||||||
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add additional embeddings for special tokens if needed
|
# Add additional embeddings for special tokens if needed
|
||||||
# This step also makes sure we are still sharing the output and input embeddings after loading weights
|
# This step also makes sure we are still sharing the output and input embeddings after loading weights
|
||||||
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
|
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
|
||||||
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
" Update input and output embeddings with new embedding matrice "
|
""" Update input and output embeddings with new embedding matrice
|
||||||
|
Make sure we are sharing the embeddings
|
||||||
|
"""
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||||
|
|
||||||
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
" Update input and output embeddings with new embedding matrice "
|
""" Update input and output embeddings with new embedding matrice
|
||||||
|
Make sure we are sharing the embeddings
|
||||||
|
"""
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user