Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more 🤗 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Initial working (forward + backward) for Flax MLM training example. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Simply code Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Addressing comments, using module and moving to LM task. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore parameter name "module" wrongly renamed model. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore correct output ordering... Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Actually commit the example 😅 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Add FlaxBertModelForMaskedLM after rebasing. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make it possible to initialize the training from scratch Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Reuse flax linen example of cross entropy loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added specific data collator for flax Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Remove todo for data collator Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added evaluation step Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added ability to provide dtype to support bfloat16 on TPU Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable flax tensorboard output Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable jax.pmap support. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Ensure batches are correctly sized to be dispatched with jax.pmap Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable bfloat16 with --fp16 cmdline args Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Correctly export metrics to tensorboard Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added dropout and ability to use it. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Effectively enable & disable during training and evaluation steps. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Oops. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable specifying kernel initializer scale Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Style. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added warmup step to the learning rate scheduler. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix typo. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Print training loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make style Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix linter issue (flake8) Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix model matching Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix dummies Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix non default dtype on Flax models Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Use the same create_position_ids_from_input_ids for FlaxRoberta Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make Roberta attention as Bert Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix copy Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Wording. Co-authored-by: Marc van Zee <marcvanzee@gmail.com> Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
This commit is contained in:
@@ -12,13 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||
@@ -101,8 +102,8 @@ class FlaxRobertaLayerNorm(nn.Module):
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||
scale_init: jnp.ndarray = nn.initializers.ones
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
@@ -122,11 +123,13 @@ class FlaxRobertaLayerNorm(nn.Module):
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
|
||||
return y
|
||||
|
||||
|
||||
@@ -139,7 +142,9 @@ class FlaxRobertaEmbedding(nn.Module):
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||
kernel_init_scale: float = 0.2
|
||||
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
@@ -155,66 +160,108 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
kernel_init_scale: float = 0.2
|
||||
dropout_rate: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
|
||||
# Embed
|
||||
w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||
jnp.atleast_2d(input_ids.astype("i4"))
|
||||
)
|
||||
p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||
jnp.atleast_2d(position_ids.astype("i4"))
|
||||
)
|
||||
t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||
)
|
||||
w_emb = FlaxRobertaEmbedding(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="word_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(input_ids.astype("i4")))
|
||||
p_emb = FlaxRobertaEmbedding(
|
||||
self.max_length,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="position_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(position_ids.astype("i4")))
|
||||
t_emb = FlaxRobertaEmbedding(
|
||||
self.type_vocab_size,
|
||||
self.hidden_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="token_type_embeddings",
|
||||
dtype=self.dtype,
|
||||
)(jnp.atleast_2d(token_type_ids.astype("i4")))
|
||||
|
||||
# Sum all embeddings
|
||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||
|
||||
# Layer Norm
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
|
||||
|
||||
return layer_norm
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
|
||||
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||
class FlaxRobertaAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
self_att = nn.attention.SelfAttention(
|
||||
num_heads=self.num_heads,
|
||||
qkv_features=self.head_size,
|
||||
dropout_rate=self.dropout_rate,
|
||||
deterministic=deterministic,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask)
|
||||
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||
class FlaxRobertaIntermediate(nn.Module):
|
||||
output_size: int
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
dense = nn.Dense(
|
||||
features=self.output_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state)
|
||||
return gelu(dense)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||
class FlaxRobertaOutput(nn.Module):
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||
hidden_state = nn.Dense(
|
||||
attention_output.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(intermediate_output)
|
||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
@@ -222,14 +269,29 @@ class FlaxRobertaLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||
output = FlaxRobertaOutput(name="output")(intermediate, attention)
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
attention = FlaxRobertaAttention(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="attention",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
intermediate = FlaxRobertaIntermediate(
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
name="intermediate",
|
||||
dtype=self.dtype,
|
||||
)(attention)
|
||||
output = FlaxRobertaOutput(
|
||||
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
||||
)(intermediate, attention, deterministic=deterministic)
|
||||
|
||||
return output
|
||||
|
||||
@@ -244,9 +306,12 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
def __call__(self, inputs, attention_mask, deterministic: bool = True):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
@@ -254,8 +319,16 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||
input_i = layer(input_i, attention_mask)
|
||||
layer = FlaxRobertaLayer(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name=f"{i}",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
input_i = layer(input_i, attention_mask, deterministic=deterministic)
|
||||
return input_i
|
||||
|
||||
|
||||
@@ -265,22 +338,40 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
layer = FlaxRobertaLayerCollection(
|
||||
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||
)(hidden_state, attention_mask)
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="layer",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
return layer
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||
class FlaxRobertaPooler(nn.Module):
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
out = nn.Dense(
|
||||
hidden_state.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(cls_token)
|
||||
return nn.tanh(out)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
@@ -293,21 +384,38 @@ class FlaxRobertaModule(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxRobertaEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
self.type_vocab_size,
|
||||
self.max_length,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="embeddings",
|
||||
dtype=self.dtype,
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxRobertaEncoder(
|
||||
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||
)(embeddings, attention_mask)
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="encoder",
|
||||
dtype=self.dtype,
|
||||
)(embeddings, attention_mask, deterministic=deterministic)
|
||||
|
||||
pooled = FlaxRobertaPooler(name="pooler")(encoder)
|
||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@@ -396,8 +504,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
|
||||
model = FlaxRobertaModule(
|
||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
||||
module = FlaxRobertaModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
@@ -406,31 +514,78 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
dropout_rate=config.hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
super().__init__(config, model, state, seed)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
super().__init__(config, module, state, seed)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = jnp.arange(
|
||||
self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
)
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return self.model.apply(
|
||||
{"params": self.params},
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
|
||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return input_ids, attention_mask, token_type_ids, position_ids
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
input_ids: jnp.ndarray
|
||||
padding_idx: int
|
||||
|
||||
Returns: jnp.ndarray
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = (input_ids != padding_idx).astype("i4")
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
return incremental_indices.astype("i4") + padding_idx
|
||||
|
||||
Reference in New Issue
Block a user