From 9d6abf9778c90441f16b534895538d5e061d410c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 25 Feb 2025 18:06:52 +0800 Subject: [PATCH] enable torchao quantization on CPU (#36146) * enable torchao quantization on CPU Signed-off-by: jiqing-feng * fix int4 Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * enable CPU torchao tests Signed-off-by: jiqing-feng * fix cuda tests Signed-off-by: jiqing-feng * fix cpu tests Signed-off-by: jiqing-feng * update tests Signed-off-by: jiqing-feng * fix style Signed-off-by: jiqing-feng * fix cuda tests Signed-off-by: jiqing-feng * fix torchao available Signed-off-by: jiqing-feng * fix torchao available Signed-off-by: jiqing-feng * fix torchao config cannot convert to json * fix docs Signed-off-by: jiqing-feng * rm to_dict to rebase Signed-off-by: jiqing-feng * limited torchao version for CPU Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * fix skip Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * Update src/transformers/testing_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix cpu test Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- docs/source/en/quantization/overview.md | 2 +- docs/source/en/quantization/torchao.md | 8 +- src/transformers/testing_utils.py | 13 ++ src/transformers/utils/quantization_config.py | 12 +- .../torchao_integration/test_torchao.py | 160 +++++++++++------- 5 files changed, 125 insertions(+), 70 deletions(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 94696e300a..61fc8bf322 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -59,7 +59,7 @@ Use the table below to help you decide which quantization method to use. | [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ | | [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto | | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | -| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 5 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | +| [torchao](./torchao.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 5 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao | | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ | | [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ | | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | | diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 06017c3f3e..c8116bf8ea 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -22,9 +22,11 @@ pip install --upgrade torch torchao transformers By default, the weights are loaded in full precision (torch.float32) regardless of the actual data type the weights are stored in such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file to automatically load the most memory-optimal data type. + ## Manually Choose Quantization Types and Settings `torchao` Provides many commonly used types of quantization, including different dtypes like int4, float8 and different flavors like weight only, dynamic quantization etc., only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into hugigngface transformers currently, but we can add more when needed. +If you want to run the following codes on CPU even with GPU available, just change `device_map="cpu"` and `quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())` where `layout` comes from `from torchao.dtypes import Int4CPULayout` which is only available from torchao 0.8.0 and higher. Users can manually specify the quantization types and settings they want to use: @@ -40,7 +42,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=" tokenizer = AutoTokenizer.from_pretrained(model_name) input_text = "What are we having for dinner?" -input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") +input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device) # auto-compile the quantized model with `cache_implementation="static"` to get speedup output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") @@ -59,7 +61,7 @@ def benchmark_fn(func: Callable, *args, **kwargs) -> float: MAX_NEW_TOKENS = 1000 print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) -bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) +bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16) output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")) @@ -122,7 +124,7 @@ quantized_model.save_pretrained(output_dir, safe_serialization=False) # load quantized model ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id -loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda") +loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto") # confirm the speedup diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1d575ad4a3..4e06b11456 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -45,6 +45,7 @@ from unittest.mock import patch import huggingface_hub.utils import urllib3 from huggingface_hub import delete_repo +from packaging import version from transformers import logging as transformers_logging @@ -963,6 +964,18 @@ def require_torchao(test_case): return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) +def require_torchao_version_greater_or_equal(torchao_version): + def decorator(test_case): + correct_torchao_version = is_torchao_available() and version.parse( + version.parse(importlib.metadata.version("torchao")).base_version + ) >= version.parse(torchao_version) + return unittest.skipUnless( + correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}." + )(test_case) + + return decorator + + def require_torch_tensorrt_fx(test_case): """Decorator marking a test that requires Torch-TensorRT FX""" return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 2ac53dc315..9d91cfa593 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1558,7 +1558,17 @@ class TorchAoConfig(QuantizationConfigMixin): def get_apply_tensor_subclass(self): _STR_TO_METHOD = self._get_torchao_quant_type_to_method() - return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) + quant_type_kwargs = self.quant_type_kwargs.copy() + if ( + not torch.cuda.is_available() + and is_torchao_available() + and self.quant_type == "int4_weight_only" + and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + ): + from torchao.dtypes import Int4CPULayout + + quant_type_kwargs["layout"] = Int4CPULayout() + return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs) def __repr__(self): config_dict = self.to_dict() diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 60694924cd..8e00450083 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -14,15 +14,18 @@ # limitations under the License. import gc +import importlib.metadata import tempfile import unittest +from packaging import version + from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig from transformers.testing_utils import ( require_torch_gpu, require_torch_multi_gpu, require_torchao, - torch_device, + require_torchao_version_greater_or_equal, ) from transformers.utils import is_torch_available, is_torchao_available @@ -38,13 +41,17 @@ if is_torchao_available(): ) from torchao.quantization.autoquant import AQMixin + if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0"): + from torchao.dtypes import Int4CPULayout -def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): + +def check_torchao_int4_wo_quantized(test_module, qlayer): weight = qlayer.weight - test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) test_module.assertEqual(weight.quant_min, 0) test_module.assertEqual(weight.quant_max, 15) - test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout)) + test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) + layout = Int4CPULayout if weight.device.type == "cpu" else TensorCoreTiledLayout + test_module.assertTrue(isinstance(weight.tensor_impl._layout, layout)) def check_autoquantized(test_module, qlayer): @@ -60,8 +67,8 @@ def check_forward(test_module, model, batch_size=1, context_size=1024): test_module.assertEqual(out.shape[1], context_size) -@require_torch_gpu @require_torchao +@require_torchao_version_greater_or_equal("0.8.0") class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): """ @@ -102,15 +109,19 @@ class TorchAoConfigTest(unittest.TestCase): quantization_config.to_json_string(use_diff=False) -@require_torch_gpu @require_torchao +@require_torchao_version_greater_or_equal("0.8.0") class TorchAoTest(unittest.TestCase): input_text = "What are we having for dinner?" max_new_tokens = 10 - EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" - model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + device = "cpu" + quant_scheme_kwargs = ( + {"group_size": 32, "layout": Int4CPULayout()} + if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + else {"group_size": 32} + ) def tearDown(self): gc.collect() @@ -121,20 +132,20 @@ class TorchAoTest(unittest.TestCase): """ Simple LLM model testing int4 weight only quantization """ - quant_config = TorchAoConfig("int4_weight_only", group_size=32) + quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs) # Note: we quantize the bfloat16 model on the fly to int4 quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, - device_map=torch_device, + device_map=self.device, quantization_config=quant_config, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) + check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) @@ -143,46 +154,51 @@ class TorchAoTest(unittest.TestCase): """ Testing the dtype of model will be modified to be bfloat16 for int4 weight only quantization """ - quant_config = TorchAoConfig("int4_weight_only", group_size=32) + quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs) # Note: we quantize the bfloat16 model on the fly to int4 quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=None, - device_map=torch_device, + device_map=self.device, quantization_config=quant_config, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - check_torchao_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) + check_torchao_int4_wo_quantized(self, quantized_model.model.layers[0].self_attn.v_proj) - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - @require_torch_multi_gpu - def test_int4wo_quant_multi_gpu(self): + def test_int8_dynamic_activation_int8_weight_quant(self): """ - Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs - set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + Simple LLM model testing int8_dynamic_activation_int8_weight """ + quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") - quant_config = TorchAoConfig("int4_weight_only", group_size=32) quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=torch.bfloat16, - device_map="auto", + device_map=self.device, quantization_config=quant_config, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) - - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + EXPECTED_OUTPUT = [ + "What are we having for dinner?\n\nJessica: (smiling)", + "What are we having for dinner?\n\nJess: (smiling) I", + ] + self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT) + + +@require_torch_gpu +class TorchAoGPUTest(TorchAoTest): + device = "cuda" + quant_scheme_kwargs = {"group_size": 32} def test_int4wo_offload(self): """ @@ -228,32 +244,35 @@ class TorchAoTest(unittest.TestCase): ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) EXPECTED_OUTPUT = "What are we having for dinner?\n- 2. What is the temperature outside" self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) - def test_int8_dynamic_activation_int8_weight_quant(self): + @require_torch_multi_gpu + def test_int4wo_quant_multi_gpu(self): """ - Simple LLM model testing int8_dynamic_activation_int8_weight + Simple test that checks if the quantized model int4 wieght only is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS """ - quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") - # Note: we quantize the bfloat16 model on the fly to int4 + quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs) quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - device_map=torch_device, + torch_dtype=torch.bfloat16, + device_map="auto", quantization_config=quant_config, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) def test_autoquant(self): """ @@ -264,11 +283,11 @@ class TorchAoTest(unittest.TestCase): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, - device_map=torch_device, + device_map=self.device, quantization_config=quant_config, ) tokenizer = AutoTokenizer.from_pretrained(self.model_name) - input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device) + input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device) output = quantized_model.generate( **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static" ) @@ -283,8 +302,8 @@ class TorchAoTest(unittest.TestCase): self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) -@require_torch_gpu @require_torchao +@require_torchao_version_greater_or_equal("0.8.0") class TorchAoSerializationTest(unittest.TestCase): input_text = "What are we having for dinner?" max_new_tokens = 10 @@ -292,8 +311,13 @@ class TorchAoSerializationTest(unittest.TestCase): # TODO: investigate why we don't have the same output as the original model for this test SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} - device = "cuda:0" + quant_scheme = "int4_weight_only" + quant_scheme_kwargs = ( + {"group_size": 32, "layout": Int4CPULayout()} + if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + else {"group_size": 32} + ) + device = "cpu" # called only once for all test in this class @classmethod @@ -325,9 +349,9 @@ class TorchAoSerializationTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False) loaded_quantized_model = AutoModelForCausalLM.from_pretrained( - self.model_name, torch_dtype=torch.bfloat16, device_map=self.device + self.model_name, torch_dtype=torch.bfloat16, device_map=device ) - input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device) output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output) @@ -336,46 +360,52 @@ class TorchAoSerializationTest(unittest.TestCase): self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT) -class TorchAoSerializationW8A8Test(TorchAoSerializationTest): - quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT - device = "cuda:0" - - -class TorchAoSerializationW8Test(TorchAoSerializationTest): - quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT - device = "cuda:0" - - class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT - device = "cpu" - def test_serialization_expected_output_cuda(self): + @require_torch_gpu + def test_serialization_expected_output_on_cuda(self): """ Test if we can serialize on device (cpu) and load/infer the model on cuda """ - new_device = "cuda:0" - self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) + self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT) class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT - device = "cpu" - def test_serialization_expected_output_cuda(self): + @require_torch_gpu + def test_serialization_expected_output_on_cuda(self): """ Test if we can serialize on device (cpu) and load/infer the model on cuda """ - new_device = "cuda:0" - self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT) + self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT) + + +@require_torch_gpu +class TorchAoSerializationGPTTest(TorchAoSerializationTest): + quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} + device = "cuda:0" + + +@require_torch_gpu +class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest): + quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} + ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + device = "cuda:0" + + +@require_torch_gpu +class TorchAoSerializationW8GPUTest(TorchAoSerializationTest): + quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} + ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + device = "cuda:0" if __name__ == "__main__":