@@ -162,7 +162,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
self.generation_config.min_length = 0
|
self.generation_config.min_length = 0
|
||||||
self.generation_config.min_new_tokens = None
|
self.generation_config.min_new_tokens = None
|
||||||
for processor in self.logits_processor:
|
for processor in self.logits_processor:
|
||||||
if type(processor) == MinLengthLogitsProcessor:
|
if isinstance(processor, MinLengthLogitsProcessor):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
|
"Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
|
||||||
"Please pass in `min_length` into `.generate()` instead"
|
"Please pass in `min_length` into `.generate()` instead"
|
||||||
|
|||||||
@@ -1599,7 +1599,7 @@ class FlaxBartForSequenceClassificationModule(nn.Module):
|
|||||||
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
|
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
|
||||||
|
|
||||||
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
|
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
|
||||||
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
|
if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
|
||||||
if len(jnp.unique(eos_mask.sum(1))) > 1:
|
if len(jnp.unique(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|
||||||
|
|||||||
@@ -356,7 +356,7 @@ class ChunkSizeTuner:
|
|||||||
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
|
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
|
||||||
consistent = True
|
consistent = True
|
||||||
for a1, a2 in zip(ac1, ac2):
|
for a1, a2 in zip(ac1, ac2):
|
||||||
assert type(ac1) == type(ac2)
|
assert type(ac1) is type(ac2)
|
||||||
if isinstance(ac1, (list, tuple)):
|
if isinstance(ac1, (list, tuple)):
|
||||||
consistent &= self._compare_arg_caches(a1, a2)
|
consistent &= self._compare_arg_caches(a1, a2)
|
||||||
elif isinstance(ac1, dict):
|
elif isinstance(ac1, dict):
|
||||||
|
|||||||
@@ -1635,7 +1635,7 @@ class FlaxMBartForSequenceClassificationModule(nn.Module):
|
|||||||
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
|
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
|
||||||
|
|
||||||
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
|
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
|
||||||
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
|
if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
|
||||||
if len(jnp.unique(eos_mask.sum(1))) > 1:
|
if len(jnp.unique(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
|
|||||||
"""
|
"""
|
||||||
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
|
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
|
||||||
assert (
|
assert (
|
||||||
type(tensors) == type(new_tensors)
|
type(tensors) is type(new_tensors)
|
||||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||||
if isinstance(tensors, (list, tuple)):
|
if isinstance(tensors, (list, tuple)):
|
||||||
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
|
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ def _parse_type_hint(hint: str) -> Dict:
|
|||||||
|
|
||||||
elif origin is Union:
|
elif origin is Union:
|
||||||
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
|
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
|
||||||
subtypes = [_parse_type_hint(t) for t in args if t != type(None)]
|
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
|
||||||
if len(subtypes) == 1:
|
if len(subtypes) == 1:
|
||||||
# A single non-null type can be expressed directly
|
# A single non-null type can be expressed directly
|
||||||
return_dict = subtypes[0]
|
return_dict = subtypes[0]
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ def _is_tf_symbolic_tensor(x):
|
|||||||
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
|
# the `is_symbolic_tensor` predicate is only available starting with TF 2.14
|
||||||
if hasattr(tf, "is_symbolic_tensor"):
|
if hasattr(tf, "is_symbolic_tensor"):
|
||||||
return tf.is_symbolic_tensor(x)
|
return tf.is_symbolic_tensor(x)
|
||||||
return type(x) == tf.Tensor
|
return isinstance(x, tf.Tensor)
|
||||||
|
|
||||||
|
|
||||||
def is_tf_symbolic_tensor(x):
|
def is_tf_symbolic_tensor(x):
|
||||||
|
|||||||
@@ -684,10 +684,10 @@ class IBertModelIntegrationTest(unittest.TestCase):
|
|||||||
# Recursively convert all the `quant_mode` attributes as `True`
|
# Recursively convert all the `quant_mode` attributes as `True`
|
||||||
if hasattr(model, "quant_mode"):
|
if hasattr(model, "quant_mode"):
|
||||||
model.quant_mode = True
|
model.quant_mode = True
|
||||||
elif type(model) == nn.Sequential:
|
elif isinstance(model, nn.Sequential):
|
||||||
for n, m in model.named_children():
|
for n, m in model.named_children():
|
||||||
self.quantize(m)
|
self.quantize(m)
|
||||||
elif type(model) == nn.ModuleList:
|
elif isinstance(model, nn.ModuleList):
|
||||||
for n in model:
|
for n in model:
|
||||||
self.quantize(n)
|
self.quantize(n)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user