@@ -23,7 +23,7 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_bert import BertConfig
|
||||
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
||||
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
class TFBertMainLayer(TFMainLayer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(config, **kwargs)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||
|
||||
Reference in New Issue
Block a user