Merge pull request #2069 from huggingface/cleaner-pt-tf-conversion
clean up PT <=> TF conversion
This commit is contained in:
@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
tf_inputs = tf.constant(inputs_list)
|
tf_inputs = tf.constant(inputs_list)
|
||||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
tfo = tf_model(tf_inputs, training=False) # build the network
|
||||||
|
|
||||||
pt_model = pt_model_class.from_pretrained(None,
|
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
|
||||||
|
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
|
||||||
config=config,
|
config=config,
|
||||||
state_dict=torch.load(pytorch_checkpoint_path,
|
state_dict=state_dict)
|
||||||
map_location='cpu'))
|
|
||||||
pt_inputs = torch.tensor(inputs_list)
|
pt_inputs = torch.tensor(inputs_list)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(pt_inputs)
|
pto = pt_model(pt_inputs)
|
||||||
@@ -139,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
|
|
||||||
|
|
||||||
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
||||||
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
|
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
|
||||||
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
||||||
|
|
||||||
if args_model_type is None:
|
if args_model_type is None:
|
||||||
@@ -187,13 +188,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
|
|
||||||
if os.path.isfile(model_shortcut_name):
|
if os.path.isfile(model_shortcut_name):
|
||||||
model_shortcut_name = 'converted_model'
|
model_shortcut_name = 'converted_model'
|
||||||
|
|
||||||
convert_pt_checkpoint_to_tf(model_type=model_type,
|
convert_pt_checkpoint_to_tf(model_type=model_type,
|
||||||
pytorch_checkpoint_path=model_file,
|
pytorch_checkpoint_path=model_file,
|
||||||
config_file=config_file,
|
config_file=config_file,
|
||||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
||||||
compare_with_pt_model=compare_with_pt_model)
|
compare_with_pt_model=compare_with_pt_model)
|
||||||
os.remove(config_file)
|
if remove_cached_files:
|
||||||
os.remove(model_file)
|
os.remove(config_file)
|
||||||
|
os.remove(model_file)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -226,6 +229,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--use_cached_models",
|
parser.add_argument("--use_cached_models",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
||||||
|
parser.add_argument("--remove_cached_files",
|
||||||
|
action='store_true',
|
||||||
|
help = "Remove pytorch models after conversion (save memory when converting in batches).")
|
||||||
parser.add_argument("--only_convert_finetuned_models",
|
parser.add_argument("--only_convert_finetuned_models",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help = "Only convert finetuned models.")
|
help = "Only convert finetuned models.")
|
||||||
@@ -245,4 +251,5 @@ if __name__ == "__main__":
|
|||||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||||
compare_with_pt_model=args.compare_with_pt_model,
|
compare_with_pt_model=args.compare_with_pt_model,
|
||||||
use_cached_models=args.use_cached_models,
|
use_cached_models=args.use_cached_models,
|
||||||
|
remove_cached_files=args.remove_cached_files,
|
||||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
||||||
|
|||||||
@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
|
if pretrained_model_name_or_path is not None and (
|
||||||
|
"albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
|
||||||
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
|
||||||
"https://github.com/google-research/google-research/issues/119 for more information.")
|
"https://github.com/google-research/google-research/issues/119 for more information.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user