All TODOs to be checked by Thom have been added.

This commit is contained in:
LysandreJik
2019-07-10 15:16:40 -04:00
parent f773faa258
commit 5288913bdd
3 changed files with 72 additions and 32 deletions

View File

@@ -496,12 +496,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
"""
Update input embeddings with new embedding matrice if needed
TODO
Args:
num_special_tokens:
num_special_tokens: Special tokens to be added to the embedding matrix
Returns:
TODO Lysandre filled Args
"""
if num_special_tokens is None or self.config.n_special == num_special_tokens:
@@ -665,7 +663,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
"""
Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings
TODO
Args:
num_special_tokens: Special tokens to be added to the embedding matrix
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
Defaults to True.
TODO Lysandre filled Args
"""
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
@@ -775,9 +779,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
TODO
""" Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings.
Args:
num_special_tokens: Special tokens to be added to the embedding matrix
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
Defaults to True.
TODO Lysandre filled Args
"""
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens)