[Flax] Add other BERT classes (#10977)
* add first code structures * add all bert models * add to init and docs * correct docs * make style
This commit is contained in:
committed by
GitHub
parent
e031162a6b
commit
e87505f3a1
@@ -209,8 +209,50 @@ FlaxBertModel
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForPreTraining
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForMaskedLM
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForNextSentencePrediction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForNextSentencePrediction
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForSequenceClassification
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForMultipleChoice
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForTokenClassification
|
||||
:members: __call__
|
||||
|
||||
|
||||
FlaxBertForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxBertForQuestionAnswering
|
||||
:members: __call__
|
||||
|
||||
@@ -1290,7 +1290,19 @@ else:
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
|
||||
_import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"])
|
||||
_import_structure["models.bert"].extend(
|
||||
[
|
||||
"FlaxBertForMaskedLM",
|
||||
"FlaxBertForMultipleChoice",
|
||||
"FlaxBertForNextSentencePrediction",
|
||||
"FlaxBertForPreTraining",
|
||||
"FlaxBertForQuestionAnswering",
|
||||
"FlaxBertForSequenceClassification",
|
||||
"FlaxBertForTokenClassification",
|
||||
"FlaxBertModel",
|
||||
"FlaxBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.roberta"].append("FlaxRobertaModel")
|
||||
else:
|
||||
from .utils import dummy_flax_objects
|
||||
@@ -2372,7 +2384,17 @@ if TYPE_CHECKING:
|
||||
if is_flax_available():
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
from .models.bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
FlaxBertPreTrainedModel,
|
||||
)
|
||||
from .models.roberta import FlaxRobertaModel
|
||||
else:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
|
||||
@@ -70,8 +70,17 @@ if is_tf_available():
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_bert"] = ["FlaxBertForMaskedLM", "FlaxBertModel"]
|
||||
|
||||
_import_structure["modeling_flax_bert"] = [
|
||||
"FlaxBertForMaskedLM",
|
||||
"FlaxBertForMultipleChoice",
|
||||
"FlaxBertForNextSentencePrediction",
|
||||
"FlaxBertForPreTraining",
|
||||
"FlaxBertForQuestionAnswering",
|
||||
"FlaxBertForSequenceClassification",
|
||||
"FlaxBertForTokenClassification",
|
||||
"FlaxBertModel",
|
||||
"FlaxBertPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
@@ -115,7 +124,17 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
from .modeling_flax_bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForSequenceClassification,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
FlaxBertPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import importlib
|
||||
|
||||
@@ -445,6 +445,30 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxBertOnlyNSPHead(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
|
||||
|
||||
def __call__(self, pooled_output):
|
||||
return self.seq_relationship(pooled_output)
|
||||
|
||||
|
||||
class FlaxBertPreTrainingHeads(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, pooled_output):
|
||||
prediction_scores = self.predictions(hidden_states)
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@@ -551,6 +575,73 @@ class FlaxBertModule(nn.Module):
|
||||
return hidden_states, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
|
||||
sentence prediction (classification)` head.
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForPreTrainingModule(config, **kwargs)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForPreTrainingModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
hidden_states, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output)
|
||||
|
||||
return (prediction_scores, seq_relationship_score)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
@@ -559,6 +650,7 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -594,15 +686,8 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.cls = FlaxBertOnlyMLMHead(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
@@ -611,7 +696,348 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
# Compute the prediction scores
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.cls(hidden_states)
|
||||
|
||||
return (logits,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForNextSentencePredictionModule(config, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForNextSentencePredictionModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
return (seq_relationship_scores,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
||||
output) e.g. for GLUE tasks.
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForSequenceClassificationModule(config, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForSequenceClassificationModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(
|
||||
self.config.num_labels,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
|
||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
return (logits,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
||||
softmax) e.g. for RocStories/SWAG tasks.
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForMultipleChoiceModule(config, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
num_choices = input_ids.shape[1]
|
||||
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
||||
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
|
||||
token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
|
||||
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
||||
|
||||
# Model
|
||||
_, pooled_output = self.bert(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
|
||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
reshaped_logits = logits.reshape(-1, num_choices)
|
||||
|
||||
return (reshaped_logits,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
||||
Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForTokenClassificationModule(config, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForTokenClassificationModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
return (logits,)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
||||
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxBertForQuestionAnsweringModule(config, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
train: bool = False,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
return (start_logits, end_logits)
|
||||
|
||||
@@ -32,6 +32,52 @@ class FlaxBertForMaskedLM:
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForNextSentencePrediction:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForPreTraining:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
@@ -41,6 +87,15 @@ class FlaxBertModel:
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxBertPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
|
||||
class FlaxRobertaModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
from transformers.models.bert.modeling_flax_bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertModelTester(unittest.TestCase):
|
||||
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_choices = num_choices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxBertModel,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertForQuestionAnswering,
|
||||
)
|
||||
if is_flax_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBertModelTester(self)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
@@ -65,6 +66,18 @@ class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
# hack for now until we have AutoModel classes
|
||||
if "ForMultipleChoice" in model_class.__name__:
|
||||
inputs_dict = {
|
||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
@@ -75,6 +88,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
@@ -83,12 +97,12 @@ class FlaxModelTesterMixin:
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||
@@ -97,7 +111,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
@@ -111,13 +125,14 @@ class FlaxModelTesterMixin:
|
||||
with self.subTest(model_class.__name__):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**inputs_dict)
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@@ -126,6 +141,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
@@ -134,10 +150,10 @@ class FlaxModelTesterMixin:
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**inputs_dict)
|
||||
outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**inputs_dict)
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
|
||||
Reference in New Issue
Block a user