[TF models] Common attributes as per #1721
This commit is contained in:
@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
name='h_._{}'.format(i)) for i in range(config.n_layer)]
|
||||
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
||||
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.wte
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
hidden_states = transformer_outputs[0]
|
||||
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
||||
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.wte
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
|
||||
Reference in New Issue
Block a user