[TF models] Common attributes as per #1721
This commit is contained in:
@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
self.encoder = TFBertEncoder(config, name='encoder')
|
||||
self.pooler = TFBertPooler(config, name='pooler')
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
||||
self.nsp = TFBertNSPHead(config, name='nsp___cls')
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
||||
self.bert = TFBertMainLayer(config, name='bert')
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user