[Flax] Allow dataclasses to be jitted (#11886)

* fix_torch_device_generate_test

* remove @

* change dataclasses to flax ones

* fix typo

* fix jitted tests

* fix bert & electra
This commit is contained in:
Patrick von Platen
2021-05-26 15:01:13 +01:00
committed by GitHub
parent e6126e1932
commit d5a72b6e19
4 changed files with 17 additions and 30 deletions

View File

@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import flax
import jaxlib.xla_extension as jax_xla
from .file_utils import ModelOutput
@dataclass
@flax.struct.dataclass
class FlaxBaseModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
@@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
@@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
@@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
@@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput):
FlaxCausalLMOutput = FlaxMaskedLMOutput
@dataclass
@flax.struct.dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
@@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
@@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
@@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
@@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
@flax.struct.dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.

View File

@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import numpy as np
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
@dataclass
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.BertForPreTraining`.

View File

@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import numpy as np
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
@dataclass
@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.ElectraForPreTraining`.

View File

@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
@jax.jit
def model_jitted(input_ids, attention_mask=None, **kwargs):
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs,
)
# jitted function cannot return OrderedDict
with self.assertRaises(TypeError):
model_jitted_return_dict(**prepared_inputs_dict)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()