Fix typing order (#39467)
* fix type order * change all Union[str, dict] to Union[dict, str] * add hf_parser test && fix test order * add deepspeed dependency * replace deepspeed with accelerator
This commit is contained in:
@@ -23,6 +23,7 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union, get_args, get_origin
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -160,7 +161,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(xx, yy)
|
||||
|
||||
def test_basic(self):
|
||||
def test_00_basic(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -174,7 +175,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
|
||||
self.assertFalse(example.flag)
|
||||
|
||||
def test_with_default(self):
|
||||
def test_01_with_default(self):
|
||||
parser = HfArgumentParser(WithDefaultExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -182,7 +183,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_default_bool(self):
|
||||
def test_02_with_default_bool(self):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
|
||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
@@ -217,7 +218,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
def test_03_with_enum(self):
|
||||
parser = HfArgumentParser(MixedTypeEnumExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -244,7 +245,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
|
||||
|
||||
def test_with_literal(self):
|
||||
def test_04_with_literal(self):
|
||||
@dataclass
|
||||
class LiteralExample:
|
||||
foo: Literal["titi", "toto", 42] = "toto"
|
||||
@@ -269,7 +270,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
|
||||
def test_with_list(self):
|
||||
def test_05_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -292,7 +293,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
def test_with_optional(self):
|
||||
def test_06_with_optional(self):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default=None, type=int)
|
||||
expected.add_argument("--bar", default=None, type=float, help="help message")
|
||||
@@ -315,7 +316,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||
|
||||
def test_with_required(self):
|
||||
def test_07_with_required(self):
|
||||
parser = HfArgumentParser(RequiredExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -330,7 +331,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_string_literal_annotation(self):
|
||||
def test_08_with_string_literal_annotation(self):
|
||||
parser = HfArgumentParser(StringLiteralAnnotationExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
@@ -347,7 +348,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_parse_dict(self):
|
||||
def test_09_parse_dict(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict = {
|
||||
@@ -361,7 +362,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = BasicExample(**args_dict)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_parse_dict_extra_key(self):
|
||||
def test_10_parse_dict_extra_key(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict = {
|
||||
@@ -374,7 +375,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
||||
|
||||
def test_parse_json(self):
|
||||
def test_11_parse_json(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_json = {
|
||||
@@ -393,7 +394,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = BasicExample(**args_dict_for_json)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_parse_yaml(self):
|
||||
def test_12_parse_yaml(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_yaml = {
|
||||
@@ -411,12 +412,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = BasicExample(**args_dict_for_yaml)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_z_integration_training_args(self):
|
||||
# make sure that this test executes last in the test suite
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
def test_valid_dict_annotation(self):
|
||||
def test_13_valid_dict_annotation(self):
|
||||
"""
|
||||
Tests to make sure that `dict` based annotations
|
||||
are correctly made in the `TrainingArguments`.
|
||||
@@ -475,7 +471,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_valid_dict_input_parsing(self):
|
||||
def test_14_valid_dict_input_parsing(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
@@ -483,3 +479,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(args.accelerator_config.split_batches, True)
|
||||
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||
|
||||
def test_15_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
@require_torch
|
||||
@patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
|
||||
def test_16_cli_input_parsing(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
self.assertEqual(training_args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||
|
||||
Reference in New Issue
Block a user