Fix model initialization in `from_pretrained`; update unknown-token handling for the GPT-2 tokenizer.
This commit is contained in:
@@ -423,7 +423,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
num_special_tokens = kwargs.pop('num_special_tokens', None)
|
num_special_tokens = kwargs.pop('num_special_tokens', None)
|
||||||
|
|
||||||
model = super().from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
model = super(GPT2PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
# Add additional embeddings for special tokens if needed
|
# Add additional embeddings for special tokens if needed
|
||||||
# This step also make sure we are still sharing the output and input embeddings after loading weights
|
# This step also make sure we are still sharing the output and input embeddings after loading weights
|
||||||
|
|||||||
@@ -431,7 +431,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
|||||||
num_special_tokens = kwargs.get('num_special_tokens', None)
|
num_special_tokens = kwargs.get('num_special_tokens', None)
|
||||||
kwargs.pop('num_special_tokens', None)
|
kwargs.pop('num_special_tokens', None)
|
||||||
|
|
||||||
model = super(PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, pretrained_model_name_or_path, *inputs, **kwargs)
|
model = super(OpenAIGPTPreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
# Add additional embeddings for special tokens if needed
|
# Add additional embeddings for special tokens if needed
|
||||||
# This step also make sure we are still sharing the output and input embeddings after loading weights
|
# This step also make sure we are still sharing the output and input embeddings after loading weights
|
||||||
|
|||||||
@@ -177,11 +177,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
return self.encoder.get(token)
|
||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
def _convert_id_to_token(self, index):
|
||||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||||
return self.decoder.get(index, self.unk_token)
|
return self.decoder.get(index)
|
||||||
|
|
||||||
def _convert_ids_to_string(self, tokens_ids):
|
def _convert_ids_to_string(self, tokens_ids):
|
||||||
"""Converts a sequence of ids in a string."""
|
"""Converts a sequence of ids in a string."""
|
||||||
|
|||||||
Reference in New Issue
Block a user