Implement JSON dump conversion for torch_dtype in TrainingArguments (#31224)

* Implement JSON dump conversion for torch_dtype in TrainingArguments

* Add unit test for converting torch_dtype in TrainingArguments to JSON

* move unit test for converting torch_dtype into TrainerIntegrationTest class

* reformating using ruff

* convert dict_torch_dtype_to_str to private method _dict_torch_dtype_to_str

---------

Co-authored-by: jun.4 <jun.4@kakaobrain.com>
This commit is contained in:
조준래
2024-06-07 23:43:34 +09:00
committed by GitHub
parent ff689f57aa
commit 60861fe1fd
2 changed files with 43 additions and 0 deletions

View File

@@ -3445,6 +3445,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
def test_torch_dtype_to_json(self):
@dataclasses.dataclass
class TorchDtypeTrainingArguments(TrainingArguments):
torch_dtype: torch.dtype = dataclasses.field(
default=torch.float32,
)
for dtype in [
"float32",
"float64",
"complex64",
"complex128",
"float16",
"bfloat16",
"uint8",
"int8",
"int16",
"int32",
"int64",
"bool",
]:
torch_dtype = getattr(torch, dtype)
with tempfile.TemporaryDirectory() as tmp_dir:
args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype)
args_dict = args.to_dict()
self.assertIn("torch_dtype", args_dict)
self.assertEqual(args_dict["torch_dtype"], dtype)
@require_torch
@is_staging_test