All TODOs to be checked by Thom have been added.
This commit is contained in:
@@ -496,12 +496,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
"""
|
||||
Update input embeddings with new embedding matrice if needed
|
||||
|
||||
TODO
|
||||
|
||||
Args:
|
||||
num_special_tokens:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
|
||||
Returns:
|
||||
TODO Lysandre filled Args
|
||||
|
||||
"""
|
||||
if num_special_tokens is None or self.config.n_special == num_special_tokens:
|
||||
@@ -665,7 +663,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||
"""
|
||||
Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings
|
||||
TODO
|
||||
|
||||
Args:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
|
||||
Defaults to True.
|
||||
|
||||
TODO Lysandre filled Args
|
||||
|
||||
"""
|
||||
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||
@@ -775,9 +779,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||
""" Update input and output embeddings with new embedding matrice
|
||||
Make sure we are sharing the embeddings
|
||||
TODO
|
||||
""" Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings.
|
||||
|
||||
Args:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
|
||||
Defaults to True.
|
||||
|
||||
TODO Lysandre filled Args
|
||||
"""
|
||||
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
|
||||
Reference in New Issue
Block a user