def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
    """
    Converts, in place, every non-string *torch_dtype* entry of `d` and of the dicts nested in it into a plain
    string holding just the type name. For example, `torch.float32` is converted into the *"float32"* string, which
    can then be stored in the json format. `to_dict` runs this on its result before returning it.

    Entries that are missing, `None` or already strings are left untouched.

    Args:
        d (`Dict[str, Any]`):
            Dictionary to normalize. It is modified in place and nothing is returned.
    """
    dtype = d.get("torch_dtype")
    if dtype is not None and not isinstance(dtype, str):
        # `str(torch.float32)` is "torch.float32": keep the component that follows the first dot. When the
        # representation has no dot at all, keep it whole instead of failing with an `IndexError`.
        parts = str(dtype).split(".")
        d["torch_dtype"] = parts[1] if len(parts) > 1 else parts[0]

    # Only values that are themselves dicts are descended into; dicts stored inside lists or tuples are not
    # visited. The recursion goes through `self` so that an overriding subclass is honored.
    for value in d.values():
        if isinstance(value, dict):
            self._dict_torch_dtype_to_str(value)