[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)

* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply Sylvain's tips

* update init

* make 0.3.0 compatible

* revert tevens changes

* revert tevens changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Patrick von Platen
2020-12-16 13:03:32 +01:00
committed by GitHub
parent 51adb97cd6
commit 640e6fe190
14 changed files with 700 additions and 359 deletions

View File

@@ -19,10 +19,11 @@ import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...utils import logging
from .configuration_roberta import RobertaConfig
@@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Derive position ids from token ids, skipping over padding.

    Every non-padding token receives its 1-based rank among the non-padding tokens seen so far along axis 1, shifted
    by ``padding_idx`` (so the first real token gets ``padding_idx + 1``). Padding tokens are assigned ``padding_idx``
    itself. Adapted from fairseq's `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray
            Token ids; the running count is taken along axis 1, so at least two dimensions are required.
        padding_idx: int
            Id of the padding token.

    Returns: jnp.ndarray
        Position ids with the same shape as ``input_ids``.
    """
    # 1 where there is a real token, 0 where there is padding (kept as int32 throughout).
    not_padding = (input_ids != padding_idx).astype("i4")
    # Running count of real tokens along the sequence axis.
    running_count = jnp.cumsum(not_padding, axis=1).astype("i4")
    # Zero out the count at padding slots so they end up exactly at padding_idx after the shift.
    positions = (running_count * not_padding).astype("i4")
    return positions + padding_idx
ROBERTA_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
@@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
@@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module):
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_state, attention_mask)
)(hidden_states, attention_mask)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
return layer_norm
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
hidden_act: str = "gelu"
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
# TODO: Add ACT2FN reference to change activation function
dense = nn.Dense(
def __call__(self, hidden_states):
hidden_states = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_state)
return gelu(dense)
)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
@@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module):
@nn.compact
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_state = nn.Dense(
hidden_states = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
return hidden_state
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
return hidden_states
class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention = FlaxRobertaAttention(
self.num_heads,
self.head_size,
@@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module):
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
intermediate = FlaxRobertaIntermediate(
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
hidden_act=self.hidden_act,
name="intermediate",
dtype=self.dtype,
)(attention)
@@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module):
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name=f"{i}",
dtype=self.dtype,
)
@@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
layer = FlaxRobertaLayerCollection(
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
hidden_act=self.hidden_act,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_state, attention_mask, deterministic=deterministic)
)(hidden_states, attention_mask, deterministic=deterministic)
return layer
@@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_state):
cls_token = hidden_state[:, 0]
def __call__(self, hidden_states):
cls_token = hidden_states[:, 0]
out = nn.Dense(
hidden_state.shape[-1],
hidden_states.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
name="dense",
dtype=self.dtype,
@@ -374,64 +399,12 @@ class FlaxRobertaPooler(nn.Module):
return nn.tanh(out)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxPreTrainedModel):
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
model_class = FlaxRobertaModule
config_class = RobertaConfig
base_model_prefix = "roberta"
@@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
return jax_state
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    """
    Initialize the module's parameters from a dummy all-zeros batch.

    Args:
        rng: PRNG key; it is split into one key for parameter initialization and one for dropout.
        input_shape: shape of the int32 zero-filled ``input_ids`` used to trace the module.

    Returns:
        The ``"params"`` entry of the variables returned by ``self.module.init``.
    """
    # Only input_ids is materialized; the remaining inputs are filled in by _check_inputs.
    dummy_input_ids = jnp.zeros(input_shape, dtype="i4")
    input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(dummy_input_ids, None, None, None)

    params_rng, dropout_rng = jax.random.split(rng)
    variables = self.module.init(
        {"params": params_rng, "dropout": dropout_rng},
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
    )
    return variables["params"]
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
@add_start_docstrings(
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
"""
def __init__(
self,
config: RobertaConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
@@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
hidden_act=config.hidden_act,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
**kwargs,
)
super().__init__(config, module, state, seed)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(
@@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
rngs=rngs,
)
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
@nn.compact
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
if position_ids is None:
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
if not self.add_pooling_layer:
return encoder
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
return input_ids, attention_mask, token_type_ids, position_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Derive position ids from token ids, skipping over padding.

    Every non-padding token receives its 1-based rank among the non-padding tokens seen so far along axis 1, shifted
    by ``padding_idx`` (so the first real token gets ``padding_idx + 1``). Padding tokens are assigned ``padding_idx``
    itself. Adapted from fairseq's `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray
            Token ids; the running count is taken along axis 1, so at least two dimensions are required.
        padding_idx: int
            Id of the padding token.

    Returns: jnp.ndarray
        Position ids with the same shape as ``input_ids``.
    """
    # 1 where there is a real token, 0 where there is padding (kept as int32 throughout).
    not_padding = (input_ids != padding_idx).astype("i4")
    # Running count of real tokens along the sequence axis.
    running_count = jnp.cumsum(not_padding, axis=1).astype("i4")
    # Zero out the count at padding slots so they end up exactly at padding_idx after the shift.
    positions = (running_count * not_padding).astype("i4")
    return positions + padding_idx
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled