[Flax] Allow dataclasses to be jitted (#11886)
* fix_torch_device_generate_test
* remove @
* change dataclasses to flax ones
* fix typo
* fix jitted tests
* fix bert & electra
This commit is contained in:
Committed by GitHub. Parent: e6126e1932. Commit: d5a72b6e19.
@@ -11,16 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import flax
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxBaseModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
@@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxBaseModelOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
@@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
@@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxMaskedLMOutput(ModelOutput):
|
||||
"""
|
||||
Base class for masked language models outputs.
|
||||
@@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput):
|
||||
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
@@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxSequenceClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
@@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of multiple choice models.
|
||||
@@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxTokenClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of token classification models.
|
||||
@@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of question answering models.
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxBertForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.BertForPreTraining`.
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig"
|
||||
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxElectraForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.ElectraForPreTraining`.
|
||||
|
||||
@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict)
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# jitted function cannot return OrderedDict
|
||||
with self.assertRaises(TypeError):
|
||||
model_jitted_return_dict(**prepared_inputs_dict)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user