Remove T5 dependency from mT5 model (#20949)
make mt5 independent from t5
This commit is contained in:
@@ -1806,7 +1806,9 @@ else:
|
|||||||
"MPNetPreTrainedModel",
|
"MPNetPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.mt5"].extend(["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"])
|
_import_structure["models.mt5"].extend(
|
||||||
|
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model", "MT5PreTrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.mvp"].extend(
|
_import_structure["models.mvp"].extend(
|
||||||
[
|
[
|
||||||
"MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -4922,7 +4924,7 @@ if TYPE_CHECKING:
|
|||||||
MPNetModel,
|
MPNetModel,
|
||||||
MPNetPreTrainedModel,
|
MPNetPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel
|
||||||
from .models.mvp import (
|
from .models.mvp import (
|
||||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
MvpForCausalLM,
|
MvpForCausalLM,
|
||||||
|
|||||||
@@ -51,7 +51,13 @@ try:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
|
_import_structure["modeling_mt5"] = [
|
||||||
|
"MT5EncoderModel",
|
||||||
|
"MT5ForConditionalGeneration",
|
||||||
|
"MT5Model",
|
||||||
|
"MT5PreTrainedModel",
|
||||||
|
"MT5Stack",
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
@@ -79,7 +85,7 @@ if TYPE_CHECKING:
|
|||||||
except OptionalDependencyNotAvailable:
|
except OptionalDependencyNotAvailable:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
|
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel, MT5Stack
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4003,6 +4003,13 @@ class MT5Model(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MT5PreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ PRIVATE_MODELS = [
|
|||||||
"LongT5Stack",
|
"LongT5Stack",
|
||||||
"RealmBertModel",
|
"RealmBertModel",
|
||||||
"T5Stack",
|
"T5Stack",
|
||||||
|
"MT5Stack",
|
||||||
"SwitchTransformersStack",
|
"SwitchTransformersStack",
|
||||||
"TFDPRSpanPredictor",
|
"TFDPRSpanPredictor",
|
||||||
"MaskFormerSwinModel",
|
"MaskFormerSwinModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user