From de2318894e4f971ea2273c653a702dc93db2bd6a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 25 Jul 2024 15:12:23 +0200 Subject: [PATCH] [warnings] fix E721 warnings (#32223) fix E721 warnings --- src/transformers/generation/candidate_generator.py | 2 +- src/transformers/models/bart/modeling_flax_bart.py | 2 +- src/transformers/models/esm/openfold_utils/chunk_utils.py | 2 +- src/transformers/models/mbart/modeling_flax_mbart.py | 2 +- src/transformers/trainer_pt_utils.py | 2 +- src/transformers/utils/chat_template_utils.py | 2 +- src/transformers/utils/generic.py | 2 +- tests/models/ibert/test_modeling_ibert.py | 4 ++-- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index e735d0a2ca..39fa67bfaf 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator): self.generation_config.min_length = 0 self.generation_config.min_new_tokens = None for processor in self.logits_processor: - if type(processor) == MinLengthLogitsProcessor: + if isinstance(processor, MinLengthLogitsProcessor): raise ValueError( "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. " "Please pass in `min_length` into `.generate()` instead" diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 507a93a8e7..634c256fe7 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module): eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation - if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer: + if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer): if len(jnp.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/esm/openfold_utils/chunk_utils.py b/src/transformers/models/esm/openfold_utils/chunk_utils.py index 16131b8590..51ff6b74d6 100644 --- a/src/transformers/models/esm/openfold_utils/chunk_utils.py +++ b/src/transformers/models/esm/openfold_utils/chunk_utils.py @@ -356,7 +356,7 @@ class ChunkSizeTuner: def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool: consistent = True for a1, a2 in zip(ac1, ac2): - assert type(ac1) == type(ac2) + assert type(ac1) is type(ac2) if isinstance(ac1, (list, tuple)): consistent &= self._compare_arg_caches(a1, a2) elif isinstance(ac1, dict): diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index 0f943df13c..83e4dcaee2 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module): eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0) # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation - if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer: + if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer): if len(jnp.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index a3c2db27d2..5f78860fe6 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100): """ if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)): assert ( - type(tensors) == type(new_tensors) + type(tensors) is type(new_tensors) ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." if isinstance(tensors, (list, tuple)): return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors)) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 414d2fb724..078a307b1c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict: elif origin is Union: # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end - subtypes = [_parse_type_hint(t) for t in args if t != type(None)] + subtypes = [_parse_type_hint(t) for t in args if t is not type(None)] if len(subtypes) == 1: # A single non-null type can be expressed directly return_dict = subtypes[0] diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 01c5ede34a..3bd5fa8cc6 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x): # the `is_symbolic_tensor` predicate is only available starting with TF 2.14 if hasattr(tf, "is_symbolic_tensor"): return tf.is_symbolic_tensor(x) - return type(x) == tf.Tensor + return isinstance(x, tf.Tensor) def is_tf_symbolic_tensor(x): diff --git a/tests/models/ibert/test_modeling_ibert.py b/tests/models/ibert/test_modeling_ibert.py index b9b5054d90..3918b3efea 100644 --- a/tests/models/ibert/test_modeling_ibert.py +++ b/tests/models/ibert/test_modeling_ibert.py @@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase): # Recursively convert all the `quant_mode` attributes as `True` if hasattr(model, "quant_mode"): model.quant_mode = True - elif type(model) == nn.Sequential: + elif isinstance(model, nn.Sequential): for n, m in model.named_children(): self.quantize(m) - elif type(model) == nn.ModuleList: + elif isinstance(model, nn.ModuleList): for n in model: self.quantize(n) else: