diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index b74db2ee4e..06a10ff5a0 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -281,7 +281,9 @@ class HfArgumentParser(ArgumentParser):
 
                 - the dataclass instances in the same order as they were passed to the initializer.
         """
-        outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
+        with open(Path(json_file), encoding="utf-8") as open_json_file:
+            data = json.loads(open_json_file.read())
+        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
         return tuple(outputs)
 
     def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
@@ -301,5 +303,5 @@ class HfArgumentParser(ArgumentParser):
 
                 - the dataclass instances in the same order as they were passed to the initializer.
         """
-        outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
+        outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
         return tuple(outputs)
diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py
index 827888509b..9dfff75948 100644
--- a/tests/utils/test_hf_argparser.py
+++ b/tests/utils/test_hf_argparser.py
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import argparse
+import json
+import os
+import tempfile
 import unittest
 from argparse import Namespace
 from dataclasses import dataclass, field
 from enum import Enum
+from pathlib import Path
 from typing import List, Optional
 
+import yaml
 from transformers import HfArgumentParser, TrainingArguments
 from transformers.hf_argparser import string_to_bool
 
@@ -258,6 +263,43 @@ class HfArgumentParserTest(unittest.TestCase):
 
         self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
 
+    def test_parse_json(self):
+        parser = HfArgumentParser(BasicExample)
+
+        args_dict_for_json = {
+            "foo": 12,
+            "bar": 3.14,
+            "baz": "42",
+            "flag": True,
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            temp_local_path = os.path.join(tmp_dir, "temp_json")
+            os.mkdir(temp_local_path)
+            with open(temp_local_path + ".json", "w+") as f:
+                json.dump(args_dict_for_json, f)
+            parsed_args = parser.parse_json_file(Path(temp_local_path + ".json"))[0]
+
+        args = BasicExample(**args_dict_for_json)
+        self.assertEqual(parsed_args, args)
+
+    def test_parse_yaml(self):
+        parser = HfArgumentParser(BasicExample)
+
+        args_dict_for_yaml = {
+            "foo": 12,
+            "bar": 3.14,
+            "baz": "42",
+            "flag": True,
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            temp_local_path = os.path.join(tmp_dir, "temp_yaml")
+            os.mkdir(temp_local_path)
+            with open(temp_local_path + ".yaml", "w+") as f:
+                yaml.dump(args_dict_for_yaml, f)
+            parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
+        args = BasicExample(**args_dict_for_yaml)
+        self.assertEqual(parsed_args, args)
+
     def test_integration_training_args(self):
         parser = HfArgumentParser(TrainingArguments)
         self.assertIsNotNone(parser)