@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
from .configuration_albert import AlbertConfig
|
from .configuration_albert import AlbertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -478,7 +478,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
class TFAlbertMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
class TFBertMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|
||||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
class TFCTRLMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_past = config.output_past
|
self.output_past = config.output_past
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -397,9 +397,9 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
class TFDistilBertMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|
||||||
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
|
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from .configuration_gpt2 import GPT2Config
|
|||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
TFConv1D,
|
TFConv1D,
|
||||||
|
TFMainLayer,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
@@ -196,9 +197,9 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
return outputs # x, present, (attentions)
|
return outputs # x, present, (attentions)
|
||||||
|
|
||||||
|
|
||||||
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
class TFGPT2MainLayer(TFMainLayer):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.num_hidden_layers = config.n_layer
|
self.num_hidden_layers = config.n_layer
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from .configuration_openai import OpenAIGPTConfig
|
|||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
TFConv1D,
|
TFConv1D,
|
||||||
|
TFMainLayer,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
@@ -197,7 +198,7 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
return outputs # x, (attentions)
|
return outputs # x, (attentions)
|
||||||
|
|
||||||
|
|
||||||
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
class TFOpenAIGPTMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_t5 import T5Config
|
from .configuration_t5 import T5Config
|
||||||
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
|
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -359,9 +359,9 @@ class TFT5Block(tf.keras.layers.Layer):
|
|||||||
# The full model without a specific pretrained or finetuning head is
|
# The full model without a specific pretrained or finetuning head is
|
||||||
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
||||||
####################################################
|
####################################################
|
||||||
class TFT5MainLayer(tf.keras.layers.Layer):
|
class TFT5MainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.is_decoder = config.is_decoder
|
self.is_decoder = config.is_decoder
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -378,9 +378,9 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
class TFTransfoXLMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,23 @@ class TFModelUtilsMixin:
|
|||||||
return self.count_params()
|
return self.count_params()
|
||||||
|
|
||||||
|
|
||||||
|
class TFMainLayer(tf.keras.layers.Layer):
    """
    A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization.

    The model configuration is stored on the layer and emitted by `get_config` under the ``"config"``
    key as a plain dict. The constructor accepts that dict form as well as a configuration object, so
    a layer can be rebuilt from the output of `get_config`.
    """

    def __init__(self, config, *args, **kwargs):
        """
        Args:
            config: Either a configuration object exposing ``to_dict()`` or a plain ``dict``
                (converted with ``PretrainedConfig.from_dict``).
            *args: Extra positional arguments, forwarded unchanged to ``tf.keras.layers.Layer``.
                Some subclasses call ``super().__init__(config, *inputs, **kwargs)``, so the
                positional pass-through must be accepted here rather than raising ``TypeError``.
            **kwargs: Keyword arguments forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        super().__init__(*args, **kwargs)
        if isinstance(config, dict):
            # `get_config` stores the configuration as a dict; turn it back into a config
            # object so subclasses can keep reading attributes such as `config.num_hidden_layers`.
            config = PretrainedConfig.from_dict(config)
        self._transformers_config = config

    def get_config(self):
        """
        Return the base Keras layer config extended with the model configuration.

        Returns:
            dict: The result of the parent ``get_config()`` with an additional ``"config"`` entry
            holding ``self._transformers_config.to_dict()``.
        """
        cfg = super().get_config()
        cfg["config"] = self._transformers_config.to_dict()
        return cfg
|
||||||
|
|
||||||
|
|
||||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||||
r""" Base class for all TF models.
|
r""" Base class for all TF models.
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,14 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_xlm import XLMConfig
|
from .configuration_xlm import XLMConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFMainLayer,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFSequenceSummary,
|
||||||
|
TFSharedEmbeddings,
|
||||||
|
get_initializer,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -196,9 +203,9 @@ class TFTransformerFFN(tf.keras.layers.Layer):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class TFXLMMainLayer(tf.keras.layers.Layer):
|
class TFXLMMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,14 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_xlnet import XLNetConfig
|
from .configuration_xlnet import XLNetConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFMainLayer,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFSequenceSummary,
|
||||||
|
TFSharedEmbeddings,
|
||||||
|
get_initializer,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -342,9 +349,9 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
class TFXLNetMainLayer(TFMainLayer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(config, **kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.output_past = config.output_past
|
self.output_past = config.output_past
|
||||||
|
|||||||
Reference in New Issue
Block a user