Merge pull request #337 from CatalinVoss/patch-2
Allow tokenization of sequences > 512 for caching
This commit is contained in:
@@ -163,7 +163,7 @@ def main():
|
|||||||
datasets = (train_dataset, eval_dataset)
|
datasets = (train_dataset, eval_dataset)
|
||||||
encoded_datasets = tokenize_and_encode(datasets)
|
encoded_datasets = tokenize_and_encode(datasets)
|
||||||
|
|
||||||
# Compute the mex input length for the Transformer
|
# Compute the max input length for the Transformer
|
||||||
max_length = model.config.n_positions // 2 - 2
|
max_length = model.config.n_positions // 2 - 2
|
||||||
input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \
|
input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \
|
||||||
for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
|
for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class BertTokenizer(object):
|
|||||||
for token in tokens:
|
for token in tokens:
|
||||||
ids.append(self.vocab[token])
|
ids.append(self.vocab[token])
|
||||||
if len(ids) > self.max_len:
|
if len(ids) > self.max_len:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
" sequence length for this BERT model ({} > {}). Running this"
|
" sequence length for this BERT model ({} > {}). Running this"
|
||||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ class GPT2Tokenizer(object):
|
|||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||||
if len(bpe_tokens) > self.max_len:
|
if len(bpe_tokens) > self.max_len:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
|
" sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
|
||||||
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
|
" sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ class OpenAIGPTTokenizer(object):
|
|||||||
else:
|
else:
|
||||||
ids.append(self.encoder.get(token, 0))
|
ids.append(self.encoder.get(token, 0))
|
||||||
if len(ids) > self.max_len:
|
if len(ids) > self.max_len:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
"Token indices sequence length is longer than the specified maximum "
|
"Token indices sequence length is longer than the specified maximum "
|
||||||
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
" sequence length for this OpenAI GPT model ({} > {}). Running this"
|
||||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||||
|
|||||||
Reference in New Issue
Block a user