Update weights loading
This commit is contained in:
@@ -653,8 +653,13 @@ class BertPreTrainedModel(nn.Module):
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
|
||||
config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
|
||||
else:
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
@@ -708,24 +713,24 @@ class BertPreTrainedModel(nn.Module):
|
||||
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||
# archive.extractall(tempdir)
|
||||
# serialization_dir = tempdir
|
||||
# config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
# if not os.path.exists(config_file):
|
||||
# # Backward compatibility with old naming format
|
||||
# config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
if not os.path.exists(config_file):
|
||||
# Backward compatibility with old naming format
|
||||
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
config = BertConfig.from_json_file(resolved_config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path, map_location='cpu')
|
||||
# weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
# if tempdir:
|
||||
# # Clean up temp dir
|
||||
# shutil.rmtree(tempdir)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||
# weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||
return load_tf_weights_in_bert(model, weights_path)
|
||||
# Load from a PyTorch state_dict
|
||||
old_keys = []
|
||||
|
||||
Reference in New Issue
Block a user