updating weights loading
This commit is contained in:
@@ -653,8 +653,13 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
if from_tf:
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
# Directly load from a TensorFlow checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
|
||||||
|
else:
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
@@ -708,24 +713,24 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||||
# archive.extractall(tempdir)
|
# archive.extractall(tempdir)
|
||||||
# serialization_dir = tempdir
|
# serialization_dir = tempdir
|
||||||
|
# config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||||
|
# if not os.path.exists(config_file):
|
||||||
|
# # Backward compatibility with old naming format
|
||||||
|
# config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||||
# Load config
|
# Load config
|
||||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
config = BertConfig.from_json_file(resolved_config_file)
|
||||||
if not os.path.exists(config_file):
|
|
||||||
# Backward compatibility with old naming format
|
|
||||||
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
|
||||||
config = BertConfig.from_json_file(config_file)
|
|
||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
# weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path, map_location='cpu')
|
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||||
# if tempdir:
|
# if tempdir:
|
||||||
# # Clean up temp dir
|
# # Clean up temp dir
|
||||||
# shutil.rmtree(tempdir)
|
# shutil.rmtree(tempdir)
|
||||||
if from_tf:
|
if from_tf:
|
||||||
# Directly load from a TensorFlow checkpoint
|
# Directly load from a TensorFlow checkpoint
|
||||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
# weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||||
return load_tf_weights_in_bert(model, weights_path)
|
return load_tf_weights_in_bert(model, weights_path)
|
||||||
# Load from a PyTorch state_dict
|
# Load from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
|
|||||||
Reference in New Issue
Block a user