Implement JSON dump conversion for torch_dtype in TrainingArguments (#31224)
* Implement JSON dump conversion for torch_dtype in TrainingArguments * Add unit test for converting torch_dtype in TrainingArguments to JSON * move unit test for converting torch_dtype into TrainerIntegrationTest class * reformating using ruff * convert dict_torch_dtype_to_str to private method _dict_torch_dtype_to_str --------- Co-authored-by: jun.4 <jun.4@kakaobrain.com>
This commit is contained in:
@@ -3445,6 +3445,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
|
||||
|
||||
def test_torch_dtype_to_json(self):
|
||||
@dataclasses.dataclass
|
||||
class TorchDtypeTrainingArguments(TrainingArguments):
|
||||
torch_dtype: torch.dtype = dataclasses.field(
|
||||
default=torch.float32,
|
||||
)
|
||||
|
||||
for dtype in [
|
||||
"float32",
|
||||
"float64",
|
||||
"complex64",
|
||||
"complex128",
|
||||
"float16",
|
||||
"bfloat16",
|
||||
"uint8",
|
||||
"int8",
|
||||
"int16",
|
||||
"int32",
|
||||
"int64",
|
||||
"bool",
|
||||
]:
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype)
|
||||
|
||||
args_dict = args.to_dict()
|
||||
self.assertIn("torch_dtype", args_dict)
|
||||
self.assertEqual(args_dict["torch_dtype"], dtype)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user