diff --git a/setup.py b/setup.py index e194982185..5f1e40f767 100644 --- a/setup.py +++ b/setup.py @@ -121,8 +121,8 @@ _deps = [ "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", - "jax>=0.4.1,<=0.4.13", - "jaxlib>=0.4.1,<=0.4.13", + "jax>=0.4.27,<=0.4.38", + "jaxlib>=0.4.27,<=0.4.38", "jieba", "jinja2>=3.1.0", "kenlm", diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index 4162f21e95..ea30526a24 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -129,7 +129,7 @@ class EnvironmentCommand(BaseTransformersCLICommand): flax_version = flax.__version__ jax_version = jax.__version__ jaxlib_version = jaxlib.__version__ - jax_backend = jax.lib.xla_bridge.get_backend().platform + jax_backend = jax.default_backend() info = { "`transformers` version": version, diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 3fc3aafdbe..f63990f771 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -28,8 +28,8 @@ deps = { "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.4.1,<=0.4.13", - "jaxlib": "jaxlib>=0.4.1,<=0.4.13", + "jax": "jax>=0.4.27,<=0.4.38", + "jaxlib": "jaxlib>=0.4.27,<=0.4.38", "jieba": "jieba", "jinja2": "jinja2>=3.1.0", "kenlm": "kenlm", diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 7c5fdf9c17..4b2378cc38 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -387,7 +387,7 @@ class FlaxLongT5Attention(nn.Module): relative_buckets += (relative_position > 0) * num_buckets relative_position = jnp.abs(relative_position) else: - relative_position = -jnp.clip(relative_position, a_max=0) + relative_position = -jnp.clip(relative_position, max=0) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -398,7 +398,7 @@ class FlaxLongT5Attention(nn.Module): relative_position_if_large = max_exact + ( jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1) relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) @@ -672,7 +672,7 @@ class FlaxLongT5LocalAttention(nn.Module): relative_buckets += (relative_position > 0) * num_buckets relative_position = jnp.abs(relative_position) else: - relative_position = -jnp.clip(relative_position, a_max=0) + relative_position = -jnp.clip(relative_position, max=0) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -683,7 +683,7 @@ class FlaxLongT5LocalAttention(nn.Module): relative_position_if_large = max_exact + ( jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1) relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) @@ -895,7 +895,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module): relative_buckets += (relative_position > 0) * num_buckets relative_position = jnp.abs(relative_position) else: - relative_position = -jnp.clip(relative_position, a_max=0) + relative_position = -jnp.clip(relative_position, max=0) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -906,7 +906,7 @@ class FlaxLongT5TransientGlobalAttention(nn.Module): relative_position_if_large = max_exact + ( jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1) relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index be76fe1b77..1eb6d208a9 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -247,7 +247,7 @@ class FlaxT5Attention(nn.Module): relative_buckets += (relative_position > 0) * num_buckets relative_position = jnp.abs(relative_position) else: - relative_position = -jnp.clip(relative_position, a_max=0) + relative_position = -jnp.clip(relative_position, max=0) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -258,7 +258,7 @@ class FlaxT5Attention(nn.Module): relative_position_if_large = max_exact + ( jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1) relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)