From 11792d7826854979bb532b6da09bc3796b09ea6a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 27 Jul 2020 12:21:25 -0400 Subject: [PATCH] CL util to convert models to fp16 before upload (#5953) --- examples/seq2seq/convert_model_to_fp16.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 examples/seq2seq/convert_model_to_fp16.py diff --git a/examples/seq2seq/convert_model_to_fp16.py b/examples/seq2seq/convert_model_to_fp16.py new file mode 100644 index 0000000000..24042cc0e7 --- /dev/null +++ b/examples/seq2seq/convert_model_to_fp16.py @@ -0,0 +1,21 @@ +from typing import Union + +import fire +import torch +from tqdm import tqdm + + +def convert(src_path: str, map_location: str = "cpu", save_path: Union[str, None] = None) -> None: + """Convert a pytorch_model.bin or model.pt file to torch.float16 for faster downloads, less disk space.""" + state_dict = torch.load(src_path, map_location=map_location) + for k, v in tqdm(state_dict.items()): + if not isinstance(v, torch.Tensor): + raise TypeError("FP16 conversion only works on paths that are saved state dics, like pytorch_model.bin") + state_dict[k] = v.half() + if save_path is None: # overwrite src_path + save_path = src_path + torch.save(state_dict, save_path) + + +if __name__ == "__main__": + fire.Fire(convert)