Added tests for yaml and json parser (#19219)
* Added tests for yaml and json * Added tests for yaml and json
This commit is contained in:
@@ -281,7 +281,9 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
- the dataclass instances in the same order as they were passed to the initializer.
|
- the dataclass instances in the same order as they were passed to the initializer.
|
||||||
"""
|
"""
|
||||||
outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
|
open_json_file = open(Path(json_file))
|
||||||
|
data = json.loads(open_json_file.read())
|
||||||
|
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
|
||||||
return tuple(outputs)
|
return tuple(outputs)
|
||||||
|
|
||||||
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
|
def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
|
||||||
@@ -301,5 +303,5 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
|
|
||||||
- the dataclass instances in the same order as they were passed to the initializer.
|
- the dataclass instances in the same order as they were passed to the initializer.
|
||||||
"""
|
"""
|
||||||
outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
|
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
|
||||||
return tuple(outputs)
|
return tuple(outputs)
|
||||||
|
|||||||
@@ -13,12 +13,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
from transformers import HfArgumentParser, TrainingArguments
|
from transformers import HfArgumentParser, TrainingArguments
|
||||||
from transformers.hf_argparser import string_to_bool
|
from transformers.hf_argparser import string_to_bool
|
||||||
|
|
||||||
@@ -258,6 +263,43 @@ class HfArgumentParserTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
||||||
|
|
||||||
|
def test_parse_json(self):
|
||||||
|
parser = HfArgumentParser(BasicExample)
|
||||||
|
|
||||||
|
args_dict_for_json = {
|
||||||
|
"foo": 12,
|
||||||
|
"bar": 3.14,
|
||||||
|
"baz": "42",
|
||||||
|
"flag": True,
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
temp_local_path = os.path.join(tmp_dir, "temp_json")
|
||||||
|
os.mkdir(temp_local_path)
|
||||||
|
with open(temp_local_path + ".json", "w+") as f:
|
||||||
|
json.dump(args_dict_for_json, f)
|
||||||
|
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".json"))[0]
|
||||||
|
|
||||||
|
args = BasicExample(**args_dict_for_json)
|
||||||
|
self.assertEqual(parsed_args, args)
|
||||||
|
|
||||||
|
def test_parse_yaml(self):
|
||||||
|
parser = HfArgumentParser(BasicExample)
|
||||||
|
|
||||||
|
args_dict_for_yaml = {
|
||||||
|
"foo": 12,
|
||||||
|
"bar": 3.14,
|
||||||
|
"baz": "42",
|
||||||
|
"flag": True,
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
|
||||||
|
os.mkdir(temp_local_path)
|
||||||
|
with open(temp_local_path + ".yaml", "w+") as f:
|
||||||
|
yaml.dump(args_dict_for_yaml, f)
|
||||||
|
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
|
||||||
|
args = BasicExample(**args_dict_for_yaml)
|
||||||
|
self.assertEqual(parsed_args, args)
|
||||||
|
|
||||||
def test_integration_training_args(self):
|
def test_integration_training_args(self):
|
||||||
parser = HfArgumentParser(TrainingArguments)
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
self.assertIsNotNone(parser)
|
self.assertIsNotNone(parser)
|
||||||
|
|||||||
Reference in New Issue
Block a user