From d5a72b6e19e9b594767f1046bc9ceec997cd69a9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 May 2021 15:01:13 +0100 Subject: [PATCH] [Flax] Allow dataclasses to be jitted (#11886) * fix_torch_device_generate_test * remove @ * change dataclasses to flax ones * fix typo * fix jitted tests * fix bert & electra --- src/transformers/modeling_flax_outputs.py | 21 +++++++++---------- .../models/bert/modeling_flax_bert.py | 4 ++-- .../models/electra/modeling_flax_electra.py | 4 ++-- tests/test_modeling_flax_common.py | 18 +++------------- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index a007ab7733..e8ad237723 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from dataclasses import dataclass from typing import Dict, Optional, Tuple +import flax import jaxlib.xla_extension as jax_xla from .file_utils import ModelOutput -@dataclass +@flax.struct.dataclass class FlaxBaseModelOutput(ModelOutput): """ Base class for model's outputs, with potential hidden states and attentions. @@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxBaseModelOutputWithPast(ModelOutput): """ Base class for model's outputs, with potential hidden states and attentions. @@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxBaseModelOutputWithPooling(ModelOutput): """ Base class for model's outputs that also contains a pooling of the last hidden states. @@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxMaskedLMOutput(ModelOutput): """ Base class for masked language models outputs. @@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput): FlaxCausalLMOutput = FlaxMaskedLMOutput -@dataclass +@flax.struct.dataclass class FlaxNextSentencePredictorOutput(ModelOutput): """ Base class for outputs of models predicting if two sentences are consecutive or not. @@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxSequenceClassifierOutput(ModelOutput): """ Base class for outputs of sentence classification models. @@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxMultipleChoiceModelOutput(ModelOutput): """ Base class for outputs of multiple choice models. @@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxTokenClassifierOutput(ModelOutput): """ Base class for outputs of token classification models. @@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput): attentions: Optional[Tuple[jax_xla.DeviceArray]] = None -@dataclass +@flax.struct.dataclass class FlaxQuestionAnsweringModelOutput(ModelOutput): """ Base class for outputs of question answering models. diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index d0b4568903..82ce4ee870 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Callable, Optional, Tuple import numpy as np +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig" _TOKENIZER_FOR_DOC = "BertTokenizer" -@dataclass +@flax.struct.dataclass class FlaxBertForPreTrainingOutput(ModelOutput): """ Output type of :class:`~transformers.BertForPreTraining`. diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index cf36715108..9d94433016 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Callable, Optional, Tuple import numpy as np +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig" _TOKENIZER_FOR_DOC = "ElectraTokenizer" -@dataclass +@flax.struct.dataclass class FlaxElectraForPreTrainingOutput(ModelOutput): """ Output type of :class:`~transformers.ElectraForPreTraining`. diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index b1dc6bf0af..e1c0322699 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -248,31 +248,19 @@ class FlaxModelTesterMixin: @jax.jit def model_jitted(input_ids, attention_mask=None, **kwargs): - return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple() + return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) with self.subTest("JIT Enabled"): - jitted_outputs = model_jitted(**prepared_inputs_dict) + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() with self.subTest("JIT Disabled"): with jax.disable_jit(): - outputs = model_jitted(**prepared_inputs_dict) + outputs = model_jitted(**prepared_inputs_dict).to_tuple() self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - @jax.jit - def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs): - return model( - input_ids=input_ids, - attention_mask=attention_mask, - **kwargs, - ) - - # jitted function cannot return OrderedDict - with self.assertRaises(TypeError): - model_jitted_return_dict(**prepared_inputs_dict) - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common()