diff --git a/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py b/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py index 454951ed38..0ee4dd8afe 100644 --- a/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py +++ b/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py @@ -418,7 +418,7 @@ class TestTheRest(TestCasePlus): with CaptureStdout() as cs: args = parser.parse_args(args) assert False, "--help is expected to sys.exit" - assert excinfo.type == SystemExit + assert excinfo.type is SystemExit expected = lightning_base.arg_to_scheduler_metavar assert expected in cs.out, "--help is expected to list the supported schedulers" @@ -429,7 +429,7 @@ class TestTheRest(TestCasePlus): with CaptureStderr() as cs: args = parser.parse_args(args) assert False, "invalid argument is expected to sys.exit" - assert excinfo.type == SystemExit + assert excinfo.type is SystemExit expected = f"invalid choice: '{unsupported_param}'" assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}" diff --git a/setup.py b/setup.py index f6a6875dcd..67f1cbfd80 100644 --- a/setup.py +++ b/setup.py @@ -157,7 +157,7 @@ _deps = [ "rhoknp>=1.1.0,<1.3.1", "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", - "ruff==0.4.4", + "ruff==0.5.1", "sacrebleu>=1.4.12,<2.0.0", "sacremoses", "safetensors>=0.4.1", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index fcbb8469b9..7644d8d68d 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -63,7 +63,7 @@ deps = { "rhoknp": "rhoknp>=1.1.0,<1.3.1", "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", - "ruff": "ruff==0.4.4", + "ruff": "ruff==0.5.1", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", "safetensors": "safetensors>=0.4.1", diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 045bf79805..4b5548fffb 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -164,7 +164,7 @@ class HfArgumentParser(ArgumentParser): ) if type(None) not in field.type.__args__: # filter `str` in Union - field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] + field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1] origin_type = getattr(field.type, "__origin__", field.type) elif bool not in field.type.__args__: # filter `NoneType` in Union (except for `Union[bool, NoneType]`) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 61077cf7c3..9d12e1e67c 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -90,7 +90,7 @@ def dtype_byte_size(dtype): 4 ``` """ - if dtype == bool: + if dtype is bool: return 1 / 8 bit_search = re.search(r"[^\d](\d+)$", dtype.name) if bit_search is None: diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 4d9173fd08..e80e3c41d2 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -398,7 +398,7 @@ class TransformerBlock(nn.Module): if output_attentions: sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples - if type(sa_output) != tuple: + if type(sa_output) is not tuple: raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") sa_output = sa_output[0] diff --git a/src/transformers/models/distilbert/modeling_flax_distilbert.py b/src/transformers/models/distilbert/modeling_flax_distilbert.py index d3c48c077a..0cb7cdb033 100644 --- a/src/transformers/models/distilbert/modeling_flax_distilbert.py +++ b/src/transformers/models/distilbert/modeling_flax_distilbert.py @@ -304,7 +304,7 @@ class FlaxTransformerBlock(nn.Module): if output_attentions: sa_output, sa_weights = sa_output else: - assert type(sa_output) == tuple + assert type(sa_output) is tuple sa_output = sa_output[0] sa_output = self.sa_layer_norm(sa_output + hidden_states) diff --git a/src/transformers/models/esm/openfold_utils/rigid_utils.py b/src/transformers/models/esm/openfold_utils/rigid_utils.py index 2bc2fe5f5c..08f5ce0a4f 100644 --- a/src/transformers/models/esm/openfold_utils/rigid_utils.py +++ b/src/transformers/models/esm/openfold_utils/rigid_utils.py @@ -343,7 +343,7 @@ class Rotation: Returns: The indexed rotation """ - if type(index) != tuple: + if type(index) is not tuple: index = (index,) if self._rot_mats is not None: @@ -827,7 +827,7 @@ class Rigid: Returns: The indexed tensor """ - if type(index) != tuple: + if type(index) is not tuple: index = (index,) return Rigid( diff --git a/src/transformers/models/markuplm/feature_extraction_markuplm.py b/src/transformers/models/markuplm/feature_extraction_markuplm.py index 73c16bad30..e3effdc910 100644 --- a/src/transformers/models/markuplm/feature_extraction_markuplm.py +++ b/src/transformers/models/markuplm/feature_extraction_markuplm.py @@ -68,7 +68,7 @@ class MarkupLMFeatureExtractor(FeatureExtractionMixin): for element in html_code.descendants: if isinstance(element, bs4.element.NavigableString): - if type(element.parent) != bs4.element.Tag: + if type(element.parent) is not bs4.element.Tag: continue text_in_this_tag = html.unescape(element).strip() diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 7aaaeb461c..b0e456db8a 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -2550,7 +2550,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) - if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple: + if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple: # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5c1ffd8516..69b547dec5 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -254,7 +254,7 @@ def reissue_pt_warnings(caught_warnings): # Reissue warnings that are not the SAVE_STATE_WARNING if len(caught_warnings) > 1: for w in caught_warnings: - if w.category != UserWarning or w.message != SAVE_STATE_WARNING: + if w.category is not UserWarning or w.message != SAVE_STATE_WARNING: warnings.warn(w.message, w.category) diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index f47d0b0c35..6dac8b8520 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -198,7 +198,7 @@ Action: ) agent.run("What is 2 multiplied by 3.6452?") assert len(agent.logs) == 7 - assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError + assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError @require_torch def test_init_agent_with_different_toolsets(self): diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 8614302baa..feb923af28 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -214,7 +214,7 @@ recur_fibo(6)""" def test_access_attributes(self): code = "integer = 1\nobj_class = integer.__class__\nobj_class" result = evaluate_python_code(code, {}, state={}) - assert result == int + assert result is int def test_list_comprehension(self): code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])" @@ -591,7 +591,7 @@ except ValueError as e: code = "type_a = float(2); type_b = str; type_c = int" state = {} result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) - assert result == int + assert result is int def test_tuple_id(self): code = """ diff --git a/tests/models/roformer/test_tokenization_roformer.py b/tests/models/roformer/test_tokenization_roformer.py index 2c5b9c65e9..6dfd0a385f 100644 --- a/tests/models/roformer/test_tokenization_roformer.py +++ b/tests/models/roformer/test_tokenization_roformer.py @@ -56,7 +56,7 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens) - def test_rust_tokenizer(self): + def test_rust_tokenizer(self): # noqa: F811 tokenizer = self.get_rust_tokenizer() input_text, output_text = self.get_chinese_input_output_texts() tokens = tokenizer.tokenize(input_text)