Fix typing order (#39467)

* fix type order

* change all Union[str, dict] to Union[dict, str]

* add hf_parser test && fix test order

* add deepspeed dependency

* replace deepspeed with accelerator
This commit is contained in:
Qizhi Chen
2025-07-17 23:47:31 +08:00
committed by GitHub
parent bda75b4011
commit 73869f2e81
4 changed files with 32 additions and 25 deletions

View File

@@ -1395,7 +1395,7 @@ def _get_torch_dtype(
def _get_device_map( def _get_device_map(
model: "PreTrainedModel", model: "PreTrainedModel",
device_map: Optional[Union[str, dict]], device_map: Optional[Union[dict, str]],
max_memory: Optional[dict], max_memory: Optional[dict],
hf_quantizer: Optional[HfQuantizer], hf_quantizer: Optional[HfQuantizer],
torch_dtype: Optional[torch.dtype], torch_dtype: Optional[torch.dtype],
@@ -2273,7 +2273,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return model return model
@classmethod @classmethod
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]: def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]:
""" """
Checks that the requested attention implementation exists and tries to get the kernel from hub Checks that the requested attention implementation exists and tries to get the kernel from hub
if `attn_implementation` matches hf kernels pattern. if `attn_implementation` matches hf kernels pattern.
@@ -2321,7 +2321,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
return attn_implementation return attn_implementation
def set_attention_implementation(self, attn_implementation: Union[str, dict]): def set_attention_implementation(self, attn_implementation: Union[dict, str]):
""" """
Checks and dispatches to the requested attention implementation. Checks and dispatches to the requested attention implementation.
""" """

View File

@@ -611,7 +611,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def set_attention_implementation(self, attn_implementation: Union[str, dict]): def set_attention_implementation(self, attn_implementation: Union[dict, str]):
""" """
Checks and dispatches to the requested attention implementation. Checks and dispatches to the requested attention implementation.
""" """

View File

@@ -811,7 +811,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def set_attention_implementation(self, attn_implementation: Union[str, dict]): def set_attention_implementation(self, attn_implementation: Union[dict, str]):
""" """
Checks and dispatches to the requested attention implementation. Checks and dispatches to the requested attention implementation.
""" """

View File

@@ -23,6 +23,7 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union, get_args, get_origin from typing import Literal, Optional, Union, get_args, get_origin
from unittest.mock import patch
import yaml import yaml
@@ -160,7 +161,7 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertEqual(xx, yy) self.assertEqual(xx, yy)
def test_basic(self): def test_00_basic(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -174,7 +175,7 @@ class HfArgumentParserTest(unittest.TestCase):
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False) (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
self.assertFalse(example.flag) self.assertFalse(example.flag)
def test_with_default(self): def test_01_with_default(self):
parser = HfArgumentParser(WithDefaultExample) parser = HfArgumentParser(WithDefaultExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -182,7 +183,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--baz", default="toto", type=str, help="help message") expected.add_argument("--baz", default="toto", type=str, help="help message")
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
def test_with_default_bool(self): def test_02_with_default_bool(self):
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?") expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?") expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
@@ -217,7 +218,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"]) args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False)) self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self): def test_03_with_enum(self):
parser = HfArgumentParser(MixedTypeEnumExample) parser = HfArgumentParser(MixedTypeEnumExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -244,7 +245,7 @@ class HfArgumentParserTest(unittest.TestCase):
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0] enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo) self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
def test_with_literal(self): def test_04_with_literal(self):
@dataclass @dataclass
class LiteralExample: class LiteralExample:
foo: Literal["titi", "toto", 42] = "toto" foo: Literal["titi", "toto", 42] = "toto"
@@ -269,7 +270,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "42"]) args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42) self.assertEqual(args.foo, 42)
def test_with_list(self): def test_05_with_list(self):
parser = HfArgumentParser(ListExample) parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -292,7 +293,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split()) args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7])) self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_with_optional(self): def test_06_with_optional(self):
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=None, type=int) expected.add_argument("--foo", default=None, type=int)
expected.add_argument("--bar", default=None, type=float, help="help message") expected.add_argument("--bar", default=None, type=float, help="help message")
@@ -315,7 +316,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split()) args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3])) self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_with_required(self): def test_07_with_required(self):
parser = HfArgumentParser(RequiredExample) parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -330,7 +331,7 @@ class HfArgumentParserTest(unittest.TestCase):
) )
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
def test_with_string_literal_annotation(self): def test_08_with_string_literal_annotation(self):
parser = HfArgumentParser(StringLiteralAnnotationExample) parser = HfArgumentParser(StringLiteralAnnotationExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
@@ -347,7 +348,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
def test_parse_dict(self): def test_09_parse_dict(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
args_dict = { args_dict = {
@@ -361,7 +362,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict) args = BasicExample(**args_dict)
self.assertEqual(parsed_args, args) self.assertEqual(parsed_args, args)
def test_parse_dict_extra_key(self): def test_10_parse_dict_extra_key(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
args_dict = { args_dict = {
@@ -374,7 +375,7 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False) self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_parse_json(self): def test_11_parse_json(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
args_dict_for_json = { args_dict_for_json = {
@@ -393,7 +394,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict_for_json) args = BasicExample(**args_dict_for_json)
self.assertEqual(parsed_args, args) self.assertEqual(parsed_args, args)
def test_parse_yaml(self): def test_12_parse_yaml(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
args_dict_for_yaml = { args_dict_for_yaml = {
@@ -411,12 +412,7 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict_for_yaml) args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args) self.assertEqual(parsed_args, args)
def test_z_integration_training_args(self): def test_13_valid_dict_annotation(self):
# make sure that this test executes last in the test suite
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
def test_valid_dict_annotation(self):
""" """
Tests to make sure that `dict` based annotations Tests to make sure that `dict` based annotations
are correctly made in the `TrainingArguments`. are correctly made in the `TrainingArguments`.
@@ -475,7 +471,7 @@ class HfArgumentParserTest(unittest.TestCase):
) )
@require_torch @require_torch
def test_valid_dict_input_parsing(self): def test_14_valid_dict_input_parsing(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments( args = TrainingArguments(
output_dir=tmp_dir, output_dir=tmp_dir,
@@ -483,3 +479,14 @@ class HfArgumentParserTest(unittest.TestCase):
) )
self.assertEqual(args.accelerator_config.split_batches, True) self.assertEqual(args.accelerator_config.split_batches, True)
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2) self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
def test_15_integration_training_args(self):
    """Smoke test: building an `HfArgumentParser` from `TrainingArguments` yields a parser object."""
    self.assertIsNotNone(HfArgumentParser(TrainingArguments))
@require_torch
@patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
def test_16_cli_input_parsing(self):
    """
    With `sys.argv` patched to carry a JSON `--accelerator_config` value, parse into
    `TrainingArguments` and check that the nested `num_steps` entry equals 2.
    """
    # No explicit argument list is passed; the patched `sys.argv` is presumably what gets parsed.
    dataclasses_out = HfArgumentParser(TrainingArguments).parse_args_into_dataclasses()
    grad_accum_kwargs = dataclasses_out[0].accelerator_config.gradient_accumulation_kwargs
    self.assertEqual(grad_accum_kwargs["num_steps"], 2)