[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -19,10 +19,11 @@ import numpy as np
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_roberta import RobertaConfig
|
||||
|
||||
@@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig"
|
||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
input_ids: jnp.ndarray
|
||||
padding_idx: int
|
||||
|
||||
Returns: jnp.ndarray
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = (input_ids != padding_idx).astype("i4")
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
return incremental_indices.astype("i4") + padding_idx
|
||||
|
||||
|
||||
ROBERTA_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||
@@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
@@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module):
|
||||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask)
|
||||
)(hidden_states, attention_mask)
|
||||
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
||||
return layer_norm
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||
class FlaxRobertaIntermediate(nn.Module):
|
||||
output_size: int
|
||||
hidden_act: str = "gelu"
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = nn.Dense(
|
||||
features=self.output_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state)
|
||||
return gelu(dense)
|
||||
)(hidden_states)
|
||||
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||
@@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||
hidden_state = nn.Dense(
|
||||
hidden_states = nn.Dense(
|
||||
attention_output.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)(intermediate_output)
|
||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
||||
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxRobertaLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention = FlaxRobertaAttention(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
@@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module):
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="attention",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||
intermediate = FlaxRobertaIntermediate(
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
hidden_act=self.hidden_act,
|
||||
name="intermediate",
|
||||
dtype=self.dtype,
|
||||
)(attention)
|
||||
@@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
@@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
hidden_act=self.hidden_act,
|
||||
name=f"{i}",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
@@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
layer = FlaxRobertaLayerCollection(
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="layer",
|
||||
dtype=self.dtype,
|
||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
||||
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return layer
|
||||
|
||||
|
||||
@@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
def __call__(self, hidden_states):
|
||||
cls_token = hidden_states[:, 0]
|
||||
out = nn.Dense(
|
||||
hidden_state.shape[-1],
|
||||
hidden_states.shape[-1],
|
||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
@@ -374,64 +399,12 @@ class FlaxRobertaPooler(nn.Module):
|
||||
return nn.tanh(out)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
class FlaxRobertaModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxRobertaEmbeddings(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
self.type_vocab_size,
|
||||
self.max_length,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="embeddings",
|
||||
dtype=self.dtype,
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxRobertaEncoder(
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="encoder",
|
||||
dtype=self.dtype,
|
||||
)(embeddings, attention_mask, deterministic=deterministic)
|
||||
|
||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
model_class = FlaxRobertaModule
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
@@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
|
||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return input_ids, attention_mask, token_type_ids, position_ids
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RobertaConfig,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
):
|
||||
module = FlaxRobertaModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
hidden_act=config.hidden_act,
|
||||
intermediate_size=config.intermediate_size,
|
||||
dropout_rate=config.hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
super().__init__(config, module, state, seed)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
@@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
class FlaxRobertaModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
dropout_rate: float = 0.0
|
||||
kernel_init_scale: float = 0.2
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||
# Embedding
|
||||
embeddings = FlaxRobertaEmbeddings(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
self.type_vocab_size,
|
||||
self.max_length,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
name="embeddings",
|
||||
dtype=self.dtype,
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxRobertaEncoder(
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
kernel_init_scale=self.kernel_init_scale,
|
||||
dropout_rate=self.dropout_rate,
|
||||
hidden_act=self.hidden_act,
|
||||
name="encoder",
|
||||
dtype=self.dtype,
|
||||
)(embeddings, attention_mask, deterministic=deterministic)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||
if not self.add_pooling_layer:
|
||||
return encoder
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return input_ids, attention_mask, token_type_ids, position_ids
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
input_ids: jnp.ndarray
|
||||
padding_idx: int
|
||||
|
||||
Returns: jnp.ndarray
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = (input_ids != padding_idx).astype("i4")
|
||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||
return incremental_indices.astype("i4") + padding_idx
|
||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
Reference in New Issue
Block a user