From e87505f3a1a6f1b301f16ecb8522451eee13e726 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Mar 2021 09:45:58 +0300 Subject: [PATCH] [Flax] Add other BERT classes (#10977) * add first code structures * add all bert models * add to init and docs * correct docs * make style --- docs/source/model_doc/bert.rst | 42 ++ src/transformers/__init__.py | 26 +- src/transformers/models/bert/__init__.py | 25 +- .../models/bert/modeling_flax_bert.py | 446 +++++++++++++++++- src/transformers/utils/dummy_flax_objects.py | 55 +++ tests/test_modeling_flax_bert.py | 27 +- tests/test_modeling_flax_common.py | 30 +- 7 files changed, 627 insertions(+), 24 deletions(-) diff --git a/docs/source/model_doc/bert.rst b/docs/source/model_doc/bert.rst index 0ed892783c..881060df18 100644 --- a/docs/source/model_doc/bert.rst +++ b/docs/source/model_doc/bert.rst @@ -209,8 +209,50 @@ FlaxBertModel :members: __call__ +FlaxBertForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertForPreTraining + :members: __call__ + + FlaxBertForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.FlaxBertForMaskedLM :members: __call__ + + +FlaxBertForNextSentencePrediction +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertForNextSentencePrediction + :members: __call__ + + +FlaxBertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertForSequenceClassification + :members: __call__ + + +FlaxBertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.FlaxBertForMultipleChoice + :members: __call__ + + +FlaxBertForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertForTokenClassification + :members: __call__ + + +FlaxBertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxBertForQuestionAnswering + :members: __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ac7a7690dd..39b65b70b7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1290,7 +1290,19 @@ else: if is_flax_available(): _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] _import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"]) - _import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"]) + _import_structure["models.bert"].extend( + [ + "FlaxBertForMaskedLM", + "FlaxBertForMultipleChoice", + "FlaxBertForNextSentencePrediction", + "FlaxBertForPreTraining", + "FlaxBertForQuestionAnswering", + "FlaxBertForSequenceClassification", + "FlaxBertForTokenClassification", + "FlaxBertModel", + "FlaxBertPreTrainedModel", + ] + ) _import_structure["models.roberta"].append("FlaxRobertaModel") else: from .utils import dummy_flax_objects @@ -2372,7 +2384,17 @@ if TYPE_CHECKING: if is_flax_available(): from .modeling_flax_utils import FlaxPreTrainedModel from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel - from .models.bert import FlaxBertForMaskedLM, FlaxBertModel + from .models.bert import ( + FlaxBertForMaskedLM, + FlaxBertForMultipleChoice, + FlaxBertForNextSentencePrediction, + FlaxBertForPreTraining, + FlaxBertForQuestionAnswering, + FlaxBertForSequenceClassification, + FlaxBertForTokenClassification, + FlaxBertModel, + FlaxBertPreTrainedModel, + ) from 
.models.roberta import FlaxRobertaModel else: # Import the same objects as dummies to get them in the namespace. diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py index 6f99979ad6..ad03369646 100644 --- a/src/transformers/models/bert/__init__.py +++ b/src/transformers/models/bert/__init__.py @@ -70,8 +70,17 @@ if is_tf_available(): ] if is_flax_available(): - _import_structure["modeling_flax_bert"] = ["FlaxBertForMaskedLM", "FlaxBertModel"] - + _import_structure["modeling_flax_bert"] = [ + "FlaxBertForMaskedLM", + "FlaxBertForMultipleChoice", + "FlaxBertForNextSentencePrediction", + "FlaxBertForPreTraining", + "FlaxBertForQuestionAnswering", + "FlaxBertForSequenceClassification", + "FlaxBertForTokenClassification", + "FlaxBertModel", + "FlaxBertPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig @@ -115,7 +124,17 @@ if TYPE_CHECKING: ) if is_flax_available(): - from .modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel + from .modeling_flax_bert import ( + FlaxBertForMaskedLM, + FlaxBertForMultipleChoice, + FlaxBertForNextSentencePrediction, + FlaxBertForPreTraining, + FlaxBertForQuestionAnswering, + FlaxBertForSequenceClassification, + FlaxBertForTokenClassification, + FlaxBertModel, + FlaxBertPreTrainedModel, + ) else: import importlib diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 8a37721d7e..52924de812 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -445,6 +445,30 @@ class FlaxBertOnlyMLMHead(nn.Module): return hidden_states +class FlaxBertOnlyNSPHead(nn.Module): + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.seq_relationship = nn.Dense(2, dtype=self.dtype) + + def __call__(self, pooled_output): + return self.seq_relationship(pooled_output) + + +class 
FlaxBertPreTrainingHeads(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) + self.seq_relationship = nn.Dense(2, dtype=self.dtype) + + def __call__(self, hidden_states, pooled_output): + prediction_scores = self.predictions(hidden_states) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + class FlaxBertPreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -551,6 +575,73 @@ class FlaxBertModule(nn.Module): return hidden_states, pooled +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class FlaxBertForPreTraining(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForPreTrainingModule(config, **kwargs) + + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, 
dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForPreTrainingModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states, pooled_output = self.bert( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output) + + return (prediction_scores, seq_relationship_score) + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): def __init__( self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs @@ -559,6 +650,7 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, input_ids, @@ -594,15 +686,8 @@ class FlaxBertForMaskedLMModule(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.bert = FlaxBertModule( - config=self.config, - add_pooling_layer=False, - ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.cls = FlaxBertOnlyMLMHead( - config=self.config, - dtype=self.dtype, - ) + self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True @@ -611,7 +696,348 @@ class 
FlaxBertForMaskedLMModule(nn.Module): hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) # Compute the prediction scores - hidden_states = self.dropout(hidden_states, deterministic=deterministic) logits = self.cls(hidden_states) return (logits,) + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, +) +class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForNextSentencePredictionModule(config, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForNextSentencePredictionModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + _, 
pooled_output = self.bert( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + seq_relationship_scores = self.cls(pooled_output) + return (seq_relationship_scores,) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForSequenceClassificationModule(config, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForSequenceClassificationModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, 
position_ids=None, deterministic: bool = True + ): + # Model + _, pooled_output = self.bert( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + return (logits,) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForMultipleChoiceModule(config, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForMultipleChoiceModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, 
dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + _, pooled_output = self.bert( + input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + ) + + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + return (reshaped_logits,) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForTokenClassificationModule(config, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForTokenClassificationModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + return (logits,) + + +@add_start_docstrings( + """ + Bert Model with a span 
classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): + def __init__( + self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + ): + module = FlaxBertForQuestionAnsweringModule(config, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: PRNGKey = None, + train: bool = False, + ): + input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( + input_ids, attention_mask, token_type_ids, position_ids + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + rngs=rngs, + ) + + +class FlaxBertForQuestionAnsweringModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + ): + # Model + hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = 
logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + return (start_logits, end_logits) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 00773af271..deea31820f 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -32,6 +32,52 @@ class FlaxBertForMaskedLM: requires_flax(self) +class FlaxBertForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + +class FlaxBertForNextSentencePrediction: + def __init__(self, *args, **kwargs): + requires_flax(self) + + +class FlaxBertForPreTraining: + def __init__(self, *args, **kwargs): + requires_flax(self) + + +class FlaxBertForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + +class FlaxBertForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + +class FlaxBertForTokenClassification: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + class FlaxBertModel: def __init__(self, *args, **kwargs): requires_flax(self) @@ -41,6 +87,15 @@ class FlaxBertModel: requires_flax(self) +class FlaxBertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + class FlaxRobertaModel: def __init__(self, *args, **kwargs): requires_flax(self) diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py index c9946021f2..fc339f7501 100644 --- a/tests/test_modeling_flax_bert.py +++ 
b/tests/test_modeling_flax_bert.py @@ -23,7 +23,16 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_ if is_flax_available(): - from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel + from transformers.models.bert.modeling_flax_bert import ( + FlaxBertForMaskedLM, + FlaxBertForMultipleChoice, + FlaxBertForNextSentencePrediction, + FlaxBertForPreTraining, + FlaxBertForQuestionAnswering, + FlaxBertForSequenceClassification, + FlaxBertForTokenClassification, + FlaxBertModel, + ) class FlaxBertModelTester(unittest.TestCase): @@ -48,6 +57,7 @@ class FlaxBertModelTester(unittest.TestCase): type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02, + num_choices=4, ): self.parent = parent self.batch_size = batch_size @@ -68,6 +78,7 @@ class FlaxBertModelTester(unittest.TestCase): self.type_vocab_size = type_vocab_size self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range + self.num_choices = num_choices def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -107,7 +118,20 @@ class FlaxBertModelTester(unittest.TestCase): @require_flax class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): - all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else () + all_model_classes = ( + ( + FlaxBertModel, + FlaxBertForPreTraining, + FlaxBertForMaskedLM, + FlaxBertForMultipleChoice, + FlaxBertForQuestionAnswering, + FlaxBertForNextSentencePrediction, + FlaxBertForTokenClassification, + FlaxBertForSequenceClassification, + ) + if is_flax_available() + else () + ) def setUp(self): self.model_tester = FlaxBertModelTester(self) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index afa436a9cf..462ac4d01d 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -12,6 +12,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. +import copy import random import tempfile @@ -65,6 +66,18 @@ class FlaxModelTesterMixin: model_tester = None all_model_classes = () + def _prepare_for_class(self, inputs_dict, model_class): + inputs_dict = copy.deepcopy(inputs_dict) + + # hack for now until we have AutoModel classes + if "ForMultipleChoice" in model_class.__name__: + inputs_dict = { + k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) + for k, v in inputs_dict.items() + } + + return inputs_dict + def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") @@ -75,6 +88,7 @@ class FlaxModelTesterMixin: for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) pt_model = pt_model_class(config).eval() @@ -83,12 +97,12 @@ class FlaxModelTesterMixin: fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()} + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - fx_outputs = fx_model(**inputs_dict) + fx_outputs = fx_model(**prepared_inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3) @@ -97,7 +111,7 @@ class FlaxModelTesterMixin: pt_model.save_pretrained(tmpdirname) fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - 
fx_outputs_loaded = fx_model_loaded(**inputs_dict) + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) self.assertEqual( len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) @@ -111,13 +125,14 @@ class FlaxModelTesterMixin: with self.subTest(model_class.__name__): model = model_class(config) - outputs = model(**inputs_dict) + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**prepared_inputs_dict) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_loaded = model_class.from_pretrained(tmpdirname) - outputs_loaded = model_loaded(**inputs_dict) + outputs_loaded = model_loaded(**prepared_inputs_dict) for output_loaded, output in zip(outputs_loaded, outputs): self.assert_almost_equals(output_loaded, output, 5e-3) @@ -126,6 +141,7 @@ class FlaxModelTesterMixin: for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) model = model_class(config) @jax.jit @@ -134,10 +150,10 @@ class FlaxModelTesterMixin: with self.subTest("JIT Disabled"): with jax.disable_jit(): - outputs = model_jitted(**inputs_dict) + outputs = model_jitted(**prepared_inputs_dict) with self.subTest("JIT Enabled"): - jitted_outputs = model_jitted(**inputs_dict) + jitted_outputs = model_jitted(**prepared_inputs_dict) self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs):