[TF models] Common attributes as per #1721
This commit is contained in:
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
name='h_._{}'.format(i)) for i in range(config.n_layer)]
|
||||
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.w
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
|
||||
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.input_embeddings
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user