Implement JSON dump conversion for torch_dtype in TrainingArguments (#31224)
* Implement JSON dump conversion for torch_dtype in TrainingArguments * Add unit test for converting torch_dtype in TrainingArguments to JSON * move unit test for converting torch_dtype into TrainerIntegrationTest class * reformating using ruff * convert dict_torch_dtype_to_str to private method _dict_torch_dtype_to_str --------- Co-authored-by: jun.4 <jun.4@kakaobrain.com>
This commit is contained in:
@@ -2370,6 +2370,18 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
return warmup_steps
|
return warmup_steps
|
||||||
|
|
||||||
|
def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
||||||
|
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
||||||
|
string, which can then be stored in the json format.
|
||||||
|
"""
|
||||||
|
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
||||||
|
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||||
|
for value in d.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
self._dict_torch_dtype_to_str(value)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
@@ -2388,6 +2400,8 @@ class TrainingArguments:
|
|||||||
# Handle the accelerator_config if passed
|
# Handle the accelerator_config if passed
|
||||||
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
|
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
|
||||||
d[k] = v.to_dict()
|
d[k] = v.to_dict()
|
||||||
|
self._dict_torch_dtype_to_str(d)
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def to_json_string(self):
|
def to_json_string(self):
|
||||||
|
|||||||
@@ -3445,6 +3445,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
)
|
)
|
||||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||||
|
|
||||||
|
def test_torch_dtype_to_json(self):
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TorchDtypeTrainingArguments(TrainingArguments):
|
||||||
|
torch_dtype: torch.dtype = dataclasses.field(
|
||||||
|
default=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in [
|
||||||
|
"float32",
|
||||||
|
"float64",
|
||||||
|
"complex64",
|
||||||
|
"complex128",
|
||||||
|
"float16",
|
||||||
|
"bfloat16",
|
||||||
|
"uint8",
|
||||||
|
"int8",
|
||||||
|
"int16",
|
||||||
|
"int32",
|
||||||
|
"int64",
|
||||||
|
"bool",
|
||||||
|
]:
|
||||||
|
torch_dtype = getattr(torch, dtype)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
args_dict = args.to_dict()
|
||||||
|
self.assertIn("torch_dtype", args_dict)
|
||||||
|
self.assertEqual(args_dict["torch_dtype"], dtype)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user