[TF models] Common attributes as per #1721
This commit is contained in:
@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
|
||||
super(TFRobertaMainLayer, self).__init__(config, **kwargs)
|
||||
self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
|
||||
self.roberta = TFRobertaMainLayer(config, name="roberta")
|
||||
self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.roberta(inputs, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user