[Flax] Allow dataclasses to be jitted (#11886)
* fix_torch_device_generate_test
* remove @
* change dataclasses to flax ones
* fix typo
* fix jitted tests
* fix bert & electra
This commit is contained in:
Committed by: GitHub
Parent: e6126e1932
Commit: d5a72b6e19
@@ -11,16 +11,15 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import flax
|
||||||
import jaxlib.xla_extension as jax_xla
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
|
||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxBaseModelOutput(ModelOutput):
|
class FlaxBaseModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for model's outputs, with potential hidden states and attentions.
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
@@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxBaseModelOutputWithPast(ModelOutput):
|
class FlaxBaseModelOutputWithPast(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for model's outputs, with potential hidden states and attentions.
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
@@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||||
@@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxMaskedLMOutput(ModelOutput):
|
class FlaxMaskedLMOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for masked language models outputs.
|
Base class for masked language models outputs.
|
||||||
@@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput):
|
|||||||
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxNextSentencePredictorOutput(ModelOutput):
|
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
@@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxSequenceClassifierOutput(ModelOutput):
|
class FlaxSequenceClassifierOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of sentence classification models.
|
Base class for outputs of sentence classification models.
|
||||||
@@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of multiple choice models.
|
Base class for outputs of multiple choice models.
|
||||||
@@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxTokenClassifierOutput(ModelOutput):
|
class FlaxTokenClassifierOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of token classification models.
|
Base class for outputs of token classification models.
|
||||||
@@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of question answering models.
|
Base class for outputs of question answering models.
|
||||||
|
|||||||
@@ -13,11 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import flax
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig"
|
|||||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxBertForPreTrainingOutput(ModelOutput):
|
class FlaxBertForPreTrainingOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Output type of :class:`~transformers.BertForPreTraining`.
|
Output type of :class:`~transformers.BertForPreTraining`.
|
||||||
|
|||||||
@@ -13,11 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import flax
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig"
|
|||||||
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxElectraForPreTrainingOutput(ModelOutput):
|
class FlaxElectraForPreTrainingOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Output type of :class:`~transformers.ElectraForPreTraining`.
|
Output type of :class:`~transformers.ElectraForPreTraining`.
|
||||||
|
|||||||
@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
|
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||||
|
|
||||||
with self.subTest("JIT Enabled"):
|
with self.subTest("JIT Enabled"):
|
||||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
with self.subTest("JIT Disabled"):
|
with self.subTest("JIT Disabled"):
|
||||||
with jax.disable_jit():
|
with jax.disable_jit():
|
||||||
outputs = model_jitted(**prepared_inputs_dict)
|
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
|
|
||||||
return model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# jitted function cannot return OrderedDict
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
model_jitted_return_dict(**prepared_inputs_dict)
|
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user