[Flax] Add other BERT classes (#10977)
* add first code structures * add all bert models * add to init and docs * correct docs * make style
This commit is contained in:
committed by
GitHub
parent
e031162a6b
commit
e87505f3a1
@@ -209,8 +209,50 @@ FlaxBertModel
|
|||||||
:members: __call__
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForPreTraining
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForPreTraining
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
FlaxBertForMaskedLM
|
FlaxBertForMaskedLM
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.FlaxBertForMaskedLM
|
.. autoclass:: transformers.FlaxBertForMaskedLM
|
||||||
:members: __call__
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForNextSentencePrediction
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForNextSentencePrediction
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForSequenceClassification
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForMultipleChoice
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForMultipleChoice
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForTokenClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForTokenClassification
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertForQuestionAnswering
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertForQuestionAnswering
|
||||||
|
:members: __call__
|
||||||
|
|||||||
@@ -1290,7 +1290,19 @@ else:
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||||
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
|
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"])
|
||||||
_import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"])
|
_import_structure["models.bert"].extend(
|
||||||
|
[
|
||||||
|
"FlaxBertForMaskedLM",
|
||||||
|
"FlaxBertForMultipleChoice",
|
||||||
|
"FlaxBertForNextSentencePrediction",
|
||||||
|
"FlaxBertForPreTraining",
|
||||||
|
"FlaxBertForQuestionAnswering",
|
||||||
|
"FlaxBertForSequenceClassification",
|
||||||
|
"FlaxBertForTokenClassification",
|
||||||
|
"FlaxBertModel",
|
||||||
|
"FlaxBertPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.roberta"].append("FlaxRobertaModel")
|
_import_structure["models.roberta"].append("FlaxRobertaModel")
|
||||||
else:
|
else:
|
||||||
from .utils import dummy_flax_objects
|
from .utils import dummy_flax_objects
|
||||||
@@ -2372,7 +2384,17 @@ if TYPE_CHECKING:
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||||
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
|
from .models.bert import (
|
||||||
|
FlaxBertForMaskedLM,
|
||||||
|
FlaxBertForMultipleChoice,
|
||||||
|
FlaxBertForNextSentencePrediction,
|
||||||
|
FlaxBertForPreTraining,
|
||||||
|
FlaxBertForQuestionAnswering,
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
|
FlaxBertForTokenClassification,
|
||||||
|
FlaxBertModel,
|
||||||
|
FlaxBertPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.roberta import FlaxRobertaModel
|
from .models.roberta import FlaxRobertaModel
|
||||||
else:
|
else:
|
||||||
# Import the same objects as dummies to get them in the namespace.
|
# Import the same objects as dummies to get them in the namespace.
|
||||||
|
|||||||
@@ -70,8 +70,17 @@ if is_tf_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_bert"] = ["FlaxBertForMaskedLM", "FlaxBertModel"]
|
_import_structure["modeling_flax_bert"] = [
|
||||||
|
"FlaxBertForMaskedLM",
|
||||||
|
"FlaxBertForMultipleChoice",
|
||||||
|
"FlaxBertForNextSentencePrediction",
|
||||||
|
"FlaxBertForPreTraining",
|
||||||
|
"FlaxBertForQuestionAnswering",
|
||||||
|
"FlaxBertForSequenceClassification",
|
||||||
|
"FlaxBertForTokenClassification",
|
||||||
|
"FlaxBertModel",
|
||||||
|
"FlaxBertPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||||
@@ -115,7 +124,17 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
from .modeling_flax_bert import (
|
||||||
|
FlaxBertForMaskedLM,
|
||||||
|
FlaxBertForMultipleChoice,
|
||||||
|
FlaxBertForNextSentencePrediction,
|
||||||
|
FlaxBertForPreTraining,
|
||||||
|
FlaxBertForQuestionAnswering,
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
|
FlaxBertForTokenClassification,
|
||||||
|
FlaxBertModel,
|
||||||
|
FlaxBertPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -445,6 +445,30 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: a single dense layer with two outputs."""

    # dtype handed to the dense layer below
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The attribute name is kept as-is; it names the submodule.
        self.seq_relationship = nn.Dense(features=2, dtype=self.dtype)

    def __call__(self, pooled_output):
        # Project the pooled representation onto the two relationship scores.
        relationship_scores = self.seq_relationship(pooled_output)
        return relationship_scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertPreTrainingHeads(nn.Module):
    """Pair of pre-training heads: an LM prediction head plus a two-output dense layer."""

    config: BertConfig
    # dtype handed to both heads
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(features=2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output):
        # The LM head consumes the per-token states, the dense layer the pooled output.
        lm_scores = self.predictions(hidden_states)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@@ -551,6 +575,73 @@ class FlaxBertModule(nn.Module):
|
|||||||
return hidden_states, pooled
|
return hidden_states, pooled
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForPreTrainingModule(config=config, dtype=dtype, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForPreTrainingModule(nn.Module):
    """BERT encoder followed by the two pre-training heads."""

    config: BertConfig
    # dtype handed to the encoder and the heads
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
    ):
        # Encoder pass: per-token states plus the pooled output.
        sequence_states, pooled = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        # Heads: LM scores from the token states, relationship scores from the pooled output.
        lm_scores, relationship_scores = self.cls(sequence_states, pooled)
        return (lm_scores, relationship_scores)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||||
@@ -559,6 +650,7 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
|||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -594,15 +686,8 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.bert = FlaxBertModule(
|
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||||
config=self.config,
|
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||||
add_pooling_layer=False,
|
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
|
||||||
self.cls = FlaxBertOnlyMLMHead(
|
|
||||||
config=self.config,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||||
@@ -611,7 +696,348 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||||
|
|
||||||
# Compute the prediction scores
|
# Compute the prediction scores
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
|
||||||
logits = self.cls(hidden_states)
|
logits = self.cls(hidden_states)
|
||||||
|
|
||||||
return (logits,)
|
return (logits,)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForNextSentencePredictionModule(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForNextSentencePredictionModule(nn.Module):
    """BERT encoder followed by the next-sentence-prediction head."""

    config: BertConfig
    # dtype handed to the encoder and the head
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
    ):
        # Encoder pass; only the pooled output feeds the head.
        _sequence_states, pooled = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        relationship_scores = self.cls(pooled)
        return (relationship_scores,)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForSequenceClassificationModule(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForSequenceClassificationModule(nn.Module):
    """BERT encoder, dropout on the pooled output, then a dense classifier."""

    config: BertConfig
    # dtype handed to the encoder and the classifier
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(features=self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
    ):
        # Encoder pass; only the pooled output feeds the classifier.
        _sequence_states, pooled = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        # Dropout is switched off when `deterministic` is true.
        pooled = self.dropout(pooled, deterministic=deterministic)
        class_logits = self.classifier(pooled)

        return (class_logits,)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForMultipleChoiceModule(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForMultipleChoiceModule(nn.Module):
    """BERT encoder with a one-output dense scorer applied to every choice."""

    config: BertConfig
    # dtype handed to the encoder and the scorer
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        def _merge_choices(tensor):
            # Collapse every leading axis into one, keeping only the last axis.
            return tensor.reshape(-1, tensor.shape[-1]) if tensor is not None else None

        # `input_ids` is required: its second axis gives the number of choices.
        # The old `if input_ids is not None` guard after this read could never
        # take its `None` branch, so it is dropped.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])
        attention_mask = _merge_choices(attention_mask)
        token_type_ids = _merge_choices(token_type_ids)
        position_ids = _merge_choices(position_ids)

        # Model
        _, pooled_output = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        # One score per flattened row; regroup them so each row holds its choices.
        reshaped_logits = logits.reshape(-1, num_choices)

        return (reshaped_logits,)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForTokenClassificationModule(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForTokenClassificationModule(nn.Module):
    """BERT encoder (no pooling layer), dropout, then a per-token dense classifier."""

    config: BertConfig
    # dtype handed to the encoder and the classifier
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(features=self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
    ):
        # Encoder pass; built without the pooling layer, so one value comes back.
        token_states = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        # Dropout is switched off when `deterministic` is true.
        token_states = self.dropout(token_states, deterministic=deterministic)
        token_logits = self.classifier(token_states)

        return (token_logits,)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    # NOTE: no class/method docstrings are added here on purpose; the decorators
    # build the documentation and a literal docstring would be merged into it.

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Forward `dtype` to the module. Previously it was only handed to the
        # base class, so the module silently kept its float32 default.
        module = FlaxBertForQuestionAnsweringModule(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: PRNGKey = None,
        train: bool = False,
    ):
        # `_check_inputs` lives on the base class (not visible here); presumably it
        # fills in defaults for the optional inputs — confirm against its definition.
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            # Caller-supplied params win; otherwise use the model's own.
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            # Fifth positional argument of the module is `deterministic`.
            not train,
            rngs=rngs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForQuestionAnsweringModule(nn.Module):
    """BERT encoder (no pooling layer) followed by a dense span-logit layer."""

    config: BertConfig
    # dtype handed to the encoder and the output layer
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names are kept as-is; they name the submodules.
        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.qa_outputs = nn.Dense(features=self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        deterministic: bool = True,
    ):
        # Encoder pass; built without the pooling layer, so one value comes back.
        token_states = self.bert(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        span_logits = self.qa_outputs(token_states)

        # Split the last axis into `num_labels` equal sections; the two-name
        # unpacking below only succeeds when that yields exactly two pieces.
        start_logits, end_logits = jnp.split(span_logits, self.config.num_labels, axis=-1)

        # Drop the trailing axis left over from the split.
        start_logits = jnp.squeeze(start_logits, axis=-1)
        end_logits = jnp.squeeze(end_logits, axis=-1)

        return (start_logits, end_logits)
|
||||||
|
|||||||
@@ -32,6 +32,52 @@ class FlaxBertForMaskedLM:
|
|||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForMultipleChoice:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # First parameter of a classmethod is the class, so name it `cls`.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForNextSentencePrediction:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Added for parity with the sibling placeholders, so this call reaches
        # `requires_flax` instead of failing with AttributeError.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForPreTraining:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Added for parity with the sibling placeholders, so this call reaches
        # `requires_flax` instead of failing with AttributeError.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForQuestionAnswering:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # First parameter of a classmethod is the class, so name it `cls`.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForSequenceClassification:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # First parameter of a classmethod is the class, so name it `cls`.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertForTokenClassification:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # First parameter of a classmethod is the class, so name it `cls`.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertModel:
|
class FlaxBertModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
@@ -41,6 +87,15 @@ class FlaxBertModel:
|
|||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertPreTrainedModel:
    """Placeholder class whose entry points all delegate to `requires_flax`."""

    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # First parameter of a classmethod is the class, so name it `cls`.
        requires_flax(cls)
|
||||||
|
|
||||||
|
|
||||||
class FlaxRobertaModel:
|
class FlaxRobertaModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_flax(self)
|
requires_flax(self)
|
||||||
|
|||||||
@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
from transformers.models.bert.modeling_flax_bert import (
    FlaxBertForMaskedLM,
    FlaxBertForMultipleChoice,
    FlaxBertForNextSentencePrediction,
    FlaxBertForPreTraining,
    FlaxBertForQuestionAnswering,
    FlaxBertForSequenceClassification,
    FlaxBertForTokenClassification,
    FlaxBertModel,
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertModelTester(unittest.TestCase):
|
class FlaxBertModelTester(unittest.TestCase):
|
||||||
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
|||||||
type_vocab_size=16,
|
type_vocab_size=16,
|
||||||
type_sequence_label_size=2,
|
type_sequence_label_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
num_choices=4,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
|||||||
self.type_vocab_size = type_vocab_size
|
self.type_vocab_size = type_vocab_size
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.num_choices = num_choices
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
|
|||||||
@require_flax
|
@require_flax
|
||||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
# Flax BERT classes exercised by the shared FlaxModelTesterMixin tests.
# Each class is listed exactly once so no class is tested twice per run.
# NOTE(review): FlaxBertForSequenceClassification is not in the import list
# visible in this file; add it here once it is imported — confirm.
all_model_classes = (
    (
        FlaxBertModel,
        FlaxBertForPreTraining,
        FlaxBertForMaskedLM,
        FlaxBertForMultipleChoice,
        FlaxBertForQuestionAnswering,
        FlaxBertForNextSentencePrediction,
        FlaxBertForTokenClassification,
    )
    if is_flax_available()
    else ()
)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxBertModelTester(self)
|
self.model_tester = FlaxBertModelTester(self)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
@@ -65,6 +66,18 @@ class FlaxModelTesterMixin:
|
|||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
|
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class):
    """Return a copy of ``inputs_dict`` shaped for ``model_class``.

    The inputs are deep-copied so the caller's dict is never modified.
    For model classes whose name contains ``"ForMultipleChoice"``, every
    value gets a new axis inserted at position 1 and is broadcast along it
    to ``self.model_tester.num_choices`` entries; all trailing dimensions
    are kept as they are. Inputs for every other class are returned
    unchanged (as a copy).

    Args:
        inputs_dict: mapping from argument name to an array that supports
            ``v[:, None]`` indexing and has a ``shape`` attribute.
        model_class: the model class the inputs are being prepared for;
            only its ``__name__`` is inspected.

    Returns:
        A new dict with the same keys as ``inputs_dict``.
    """
    inputs_dict = copy.deepcopy(inputs_dict)

    # hack for now until we have AutoModel classes
    if "ForMultipleChoice" in model_class.__name__:
        num_choices = self.model_tester.num_choices
        inputs_dict = {
            # Keep every trailing dimension rather than only the last one, so
            # 1-D and >2-D inputs broadcast correctly as well as 2-D ones.
            k: jnp.broadcast_to(v[:, None], (v.shape[0], num_choices, *v.shape[1:]))
            for k, v in inputs_dict.items()
        }

    return inputs_dict
|
||||||
|
|
||||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
diff = np.abs((a - b)).max()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
@@ -75,6 +88,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
pt_model = pt_model_class(config).eval()
|
pt_model = pt_model_class(config).eval()
|
||||||
@@ -83,12 +97,12 @@ class FlaxModelTesterMixin:
|
|||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
|
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_outputs = fx_model(**inputs_dict)
|
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||||
@@ -97,7 +111,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
)
|
)
|
||||||
@@ -111,13 +125,14 @@ class FlaxModelTesterMixin:
|
|||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
outputs = model(**inputs_dict)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
outputs = model(**prepared_inputs_dict)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
outputs_loaded = model_loaded(**inputs_dict)
|
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||||
|
|
||||||
@@ -126,6 +141,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
@@ -134,10 +150,10 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
with self.subTest("JIT Disabled"):
|
with self.subTest("JIT Disabled"):
|
||||||
with jax.disable_jit():
|
with jax.disable_jit():
|
||||||
outputs = model_jitted(**inputs_dict)
|
outputs = model_jitted(**prepared_inputs_dict)
|
||||||
|
|
||||||
with self.subTest("JIT Enabled"):
|
with self.subTest("JIT Enabled"):
|
||||||
jitted_outputs = model_jitted(**inputs_dict)
|
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||||
|
|
||||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
|||||||
Reference in New Issue
Block a user