Flax T5 (#12150)
* copy pytorch-t5
* init
* boom boom
* forward pass same
* make generation work
* add more tests
* make test work
* finish normal tests
* make fix-copies
* finish quality
* correct slow example
* correct slow test
* version table
* upload models
* Update tests/test_modeling_flax_t5.py
* correct incorrectly deleted line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
@@ -396,7 +396,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| TAPAS | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| TAPAS | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|||||||
@@ -160,3 +160,15 @@ TFT5EncoderModel
|
|||||||
|
|
||||||
.. autoclass:: transformers.TFT5EncoderModel
|
.. autoclass:: transformers.TFT5EncoderModel
|
||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
FlaxT5Model
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxT5Model
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|
||||||
|
FlaxT5ForConditionalGeneration
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxT5ForConditionalGeneration
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -114,6 +114,7 @@ _deps = [
|
|||||||
"onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools>=1.4.2",
|
||||||
"onnxruntime>=1.4.0",
|
"onnxruntime>=1.4.0",
|
||||||
"optuna",
|
"optuna",
|
||||||
|
"optax>=0.0.8",
|
||||||
"packaging",
|
"packaging",
|
||||||
"parameterized",
|
"parameterized",
|
||||||
"protobuf",
|
"protobuf",
|
||||||
@@ -234,7 +235,7 @@ if os.name == "nt": # windows
|
|||||||
extras["flax"] = [] # jax is not supported on windows
|
extras["flax"] = [] # jax is not supported on windows
|
||||||
else:
|
else:
|
||||||
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
|
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
|
||||||
extras["flax"] = deps_list("jax", "jaxlib", "flax")
|
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
|
||||||
|
|
||||||
extras["tokenizers"] = deps_list("tokenizers")
|
extras["tokenizers"] = deps_list("tokenizers")
|
||||||
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
||||||
|
|||||||
@@ -1597,6 +1597,7 @@ if is_flax_available():
|
|||||||
"FlaxRobertaPreTrainedModel",
|
"FlaxRobertaPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
|
||||||
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
|
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
|
||||||
else:
|
else:
|
||||||
from .utils import dummy_flax_objects
|
from .utils import dummy_flax_objects
|
||||||
@@ -2920,6 +2921,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxRobertaModel,
|
FlaxRobertaModel,
|
||||||
FlaxRobertaPreTrainedModel,
|
FlaxRobertaPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
|
||||||
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
|
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
|
||||||
else:
|
else:
|
||||||
# Import the same objects as dummies to get them in the namespace.
|
# Import the same objects as dummies to get them in the namespace.
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ deps = {
|
|||||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||||
"onnxruntime": "onnxruntime>=1.4.0",
|
"onnxruntime": "onnxruntime>=1.4.0",
|
||||||
"optuna": "optuna",
|
"optuna": "optuna",
|
||||||
|
"optax": "optax>=0.0.8",
|
||||||
"packaging": "packaging",
|
"packaging": "packaging",
|
||||||
"parameterized": "parameterized",
|
"parameterized": "parameterized",
|
||||||
"protobuf": "protobuf",
|
"protobuf": "protobuf",
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from ..roberta.modeling_flax_roberta import (
|
|||||||
FlaxRobertaForTokenClassification,
|
FlaxRobertaForTokenClassification,
|
||||||
FlaxRobertaModel,
|
FlaxRobertaModel,
|
||||||
)
|
)
|
||||||
|
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
|
||||||
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||||
from .auto_factory import auto_class_factory
|
from .auto_factory import auto_class_factory
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
@@ -72,6 +73,7 @@ from .configuration_auto import (
|
|||||||
ElectraConfig,
|
ElectraConfig,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
|
T5Config,
|
||||||
ViTConfig,
|
ViTConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,6 +92,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
|||||||
(ElectraConfig, FlaxElectraModel),
|
(ElectraConfig, FlaxElectraModel),
|
||||||
(CLIPConfig, FlaxCLIPModel),
|
(CLIPConfig, FlaxCLIPModel),
|
||||||
(ViTConfig, FlaxViTModel),
|
(ViTConfig, FlaxViTModel),
|
||||||
|
(T5Config, FlaxT5Model),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,6 +104,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
|||||||
(BigBirdConfig, FlaxBigBirdForPreTraining),
|
(BigBirdConfig, FlaxBigBirdForPreTraining),
|
||||||
(BartConfig, FlaxBartForConditionalGeneration),
|
(BartConfig, FlaxBartForConditionalGeneration),
|
||||||
(ElectraConfig, FlaxElectraForPreTraining),
|
(ElectraConfig, FlaxElectraForPreTraining),
|
||||||
|
(T5Config, FlaxT5ForConditionalGeneration),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,6 +119,14 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Seq2Seq Causal LM mapping
|
||||||
|
(BartConfig, FlaxBartForConditionalGeneration),
|
||||||
|
(T5Config, FlaxT5ForConditionalGeneration),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
|
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Image-classification
|
# Model for Image-classification
|
||||||
@@ -234,3 +246,9 @@ FlaxAutoModelForNextSentencePrediction = auto_class_factory(
|
|||||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
head_doc="next sentence prediction",
|
head_doc="next sentence prediction",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
|
||||||
|
"FlaxAutoModelForSeq2SeqLM",
|
||||||
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
head_doc="sequence-to-sequence language modeling",
|
||||||
|
)
|
||||||
|
|||||||
@@ -229,7 +229,6 @@ class FlaxBartAttention(nn.Module):
|
|||||||
embed_dim: int
|
embed_dim: int
|
||||||
num_heads: int
|
num_heads: int
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
is_decoder: bool = False
|
|
||||||
causal: bool = False
|
causal: bool = False
|
||||||
bias: bool = True
|
bias: bool = True
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
@@ -510,7 +509,6 @@ class FlaxBartDecoderLayer(nn.Module):
|
|||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=self.config.decoder_attention_heads,
|
num_heads=self.config.decoder_attention_heads,
|
||||||
dropout=self.config.attention_dropout,
|
dropout=self.config.attention_dropout,
|
||||||
is_decoder=True,
|
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||||
@@ -523,7 +521,6 @@ class FlaxBartDecoderLayer(nn.Module):
|
|||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=self.config.decoder_attention_heads,
|
num_heads=self.config.decoder_attention_heads,
|
||||||
dropout=self.config.attention_dropout,
|
dropout=self.config.attention_dropout,
|
||||||
is_decoder=True,
|
|
||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
_BaseLazyModule,
|
_BaseLazyModule,
|
||||||
|
is_flax_available,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
@@ -56,6 +57,13 @@ if is_tf_available():
|
|||||||
"TFT5PreTrainedModel",
|
"TFT5PreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
_import_structure["modeling_flax_t5"] = [
|
||||||
|
"FlaxT5ForConditionalGeneration",
|
||||||
|
"FlaxT5Model",
|
||||||
|
"FlaxT5PreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||||
@@ -85,6 +93,10 @@ if TYPE_CHECKING:
|
|||||||
TFT5PreTrainedModel,
|
TFT5PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|||||||
1584
src/transformers/models/t5/modeling_flax_t5.py
Normal file
1584
src/transformers/models/t5/modeling_flax_t5.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -570,6 +570,24 @@ class FlaxRobertaPreTrainedModel:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxT5ForConditionalGeneration:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxT5Model:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxViTForImageClassification:
|
class FlaxViTForImageClassification:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FlaxBartModelTester(unittest.TestCase):
|
class FlaxBartModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
|
|||||||
513
tests/test_modeling_flax_t5.py
Normal file
513
tests/test_modeling_flax_t5.py
Normal file
File diff suppressed because one or more lines are too long
@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
|||||||
def tokenizer(self):
|
def tokenizer(self):
|
||||||
return T5Tokenizer.from_pretrained("t5-base")
|
return T5Tokenizer.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_small_generation(self):
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
|
||||||
|
model.config.max_length = 8
|
||||||
|
model.config.num_beams = 1
|
||||||
|
model.config.do_sample = False
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||||
|
|
||||||
|
input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
sequences = model.generate(input_ids)
|
||||||
|
|
||||||
|
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||||
|
self.assertTrue(output_str == "Hello there!")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_integration_test(self):
|
def test_small_integration_test(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user