From 73869f2e81467db8422cbb4831cce9a7bdc85c4b Mon Sep 17 00:00:00 2001 From: Qizhi Chen Date: Thu, 17 Jul 2025 23:47:31 +0800 Subject: [PATCH] Fix typing order (#39467) * fix type order * change all Union[str, dict] to Union[dict, str] * add hf_parser test && fix test order * add deepspeed dependency * replace deepspeed with accelerator --- src/transformers/modeling_utils.py | 6 +-- .../models/modernbert/modeling_modernbert.py | 2 +- .../models/modernbert/modular_modernbert.py | 2 +- tests/utils/test_hf_argparser.py | 47 +++++++++++-------- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4672af7165..a97ba8511d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1395,7 +1395,7 @@ def _get_torch_dtype( def _get_device_map( model: "PreTrainedModel", - device_map: Optional[Union[str, dict]], + device_map: Optional[Union[dict, str]], max_memory: Optional[dict], hf_quantizer: Optional[HfQuantizer], torch_dtype: Optional[torch.dtype], @@ -2273,7 +2273,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return model @classmethod - def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]: + def _check_attn_implementation(cls, attn_implementation: Union[dict, str]) -> Union[dict, str]: """ Checks that the requested attention implementation exists and tries to get the kernel from hub if `attn_implementation` matches hf kernels pattern. @@ -2321,7 +2321,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return attn_implementation - def set_attention_implementation(self, attn_implementation: Union[str, dict]): + def set_attention_implementation(self, attn_implementation: Union[dict, str]): """ Checks and dispatches to the requested attention implementation. """ diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 9b26635835..a76a6fead7 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -611,7 +611,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def set_attention_implementation(self, attn_implementation: Union[str, dict]): + def set_attention_implementation(self, attn_implementation: Union[dict, str]): """ Checks and dispatches to hhe requested attention implementation. """ diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 3e4041bd8b..254b5d3163 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -811,7 +811,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def set_attention_implementation(self, attn_implementation: Union[str, dict]): + def set_attention_implementation(self, attn_implementation: Union[dict, str]): """ Checks and dispatches to hhe requested attention implementation. """ diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py index 773f244008..27ecb84306 100644 --- a/tests/utils/test_hf_argparser.py +++ b/tests/utils/test_hf_argparser.py @@ -23,6 +23,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import Literal, Optional, Union, get_args, get_origin +from unittest.mock import patch import yaml @@ -160,7 +161,7 @@ class HfArgumentParserTest(unittest.TestCase): self.assertEqual(xx, yy) - def test_basic(self): + def test_00_basic(self): parser = HfArgumentParser(BasicExample) expected = argparse.ArgumentParser() @@ -174,7 +175,7 @@ class HfArgumentParserTest(unittest.TestCase): (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False) self.assertFalse(example.flag) - def test_with_default(self): + def test_01_with_default(self): parser = HfArgumentParser(WithDefaultExample) expected = argparse.ArgumentParser() @@ -182,7 +183,7 @@ class HfArgumentParserTest(unittest.TestCase): expected.add_argument("--baz", default="toto", type=str, help="help message") self.argparsersEqual(parser, expected) - def test_with_default_bool(self): + def test_02_with_default_bool(self): expected = argparse.ArgumentParser() expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?") expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?") @@ -217,7 +218,7 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"]) self.assertEqual(args, Namespace(foo=False, baz=False, opt=False)) - def test_with_enum(self): + def test_03_with_enum(self): parser = HfArgumentParser(MixedTypeEnumExample) expected = argparse.ArgumentParser() @@ -244,7 +245,7 @@ class HfArgumentParserTest(unittest.TestCase): enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0] self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo) - def test_with_literal(self): + def test_04_with_literal(self): @dataclass class LiteralExample: foo: Literal["titi", "toto", 42] = "toto" @@ -269,7 +270,7 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args(["--foo", "42"]) self.assertEqual(args.foo, 42) - def test_with_list(self): + def test_05_with_list(self): parser = HfArgumentParser(ListExample) expected = argparse.ArgumentParser() @@ -292,7 +293,7 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split()) self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7])) - def test_with_optional(self): + def test_06_with_optional(self): expected = argparse.ArgumentParser() expected.add_argument("--foo", default=None, type=int) expected.add_argument("--bar", default=None, type=float, help="help message") @@ -315,7 +316,7 @@ class HfArgumentParserTest(unittest.TestCase): args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split()) self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3])) - def test_with_required(self): + def test_07_with_required(self): parser = HfArgumentParser(RequiredExample) expected = argparse.ArgumentParser() @@ -330,7 +331,7 @@ class HfArgumentParserTest(unittest.TestCase): ) self.argparsersEqual(parser, expected) - def test_with_string_literal_annotation(self): + def test_08_with_string_literal_annotation(self): parser = HfArgumentParser(StringLiteralAnnotationExample) expected = argparse.ArgumentParser() @@ -347,7 +348,7 @@ class HfArgumentParserTest(unittest.TestCase): expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) self.argparsersEqual(parser, expected) - def test_parse_dict(self): + def test_09_parse_dict(self): parser = HfArgumentParser(BasicExample) args_dict = { @@ -361,7 +362,7 @@ class HfArgumentParserTest(unittest.TestCase): args = BasicExample(**args_dict) self.assertEqual(parsed_args, args) - def test_parse_dict_extra_key(self): + def test_10_parse_dict_extra_key(self): parser = HfArgumentParser(BasicExample) args_dict = { @@ -374,7 +375,7 @@ class HfArgumentParserTest(unittest.TestCase): self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False) - def test_parse_json(self): + def test_11_parse_json(self): parser = HfArgumentParser(BasicExample) args_dict_for_json = { @@ -393,7 +394,7 @@ class HfArgumentParserTest(unittest.TestCase): args = BasicExample(**args_dict_for_json) self.assertEqual(parsed_args, args) - def test_parse_yaml(self): + def test_12_parse_yaml(self): parser = HfArgumentParser(BasicExample) args_dict_for_yaml = { @@ -411,12 +412,7 @@ class HfArgumentParserTest(unittest.TestCase): args = BasicExample(**args_dict_for_yaml) self.assertEqual(parsed_args, args) - def test_z_integration_training_args(self): - # make sure that this test executes last in the test suite - parser = HfArgumentParser(TrainingArguments) - self.assertIsNotNone(parser) - - def test_valid_dict_annotation(self): + def test_13_valid_dict_annotation(self): """ Tests to make sure that `dict` based annotations are correctly made in the `TrainingArguments`. @@ -475,7 +471,7 @@ class HfArgumentParserTest(unittest.TestCase): ) @require_torch - def test_valid_dict_input_parsing(self): + def test_14_valid_dict_input_parsing(self): with tempfile.TemporaryDirectory() as tmp_dir: args = TrainingArguments( output_dir=tmp_dir, @@ -483,3 +479,14 @@ class HfArgumentParserTest(unittest.TestCase): ) self.assertEqual(args.accelerator_config.split_batches, True) self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2) + + def test_15_integration_training_args(self): + parser = HfArgumentParser(TrainingArguments) + self.assertIsNotNone(parser) + + @require_torch + @patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}']) + def test_16_cli_input_parsing(self): + parser = HfArgumentParser(TrainingArguments) + training_args = parser.parse_args_into_dataclasses()[0] + self.assertEqual(training_args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)