diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py
index c4148e283c..38423de14b 100644
--- a/pytorch_transformers/__init__.py
+++ b/pytorch_transformers/__init__.py
@@ -7,7 +7,6 @@ from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
-from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)
from .tokenization_utils import (PreTrainedTokenizer)
@@ -39,7 +38,7 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
-from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel,
+from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py
index 43c9362b30..6cd4bc2d35 100644
--- a/pytorch_transformers/modeling_roberta.py
+++ b/pytorch_transformers/modeling_roberta.py
@@ -23,7 +23,7 @@ import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
BertLayerNorm, BertModel,
@@ -144,7 +144,6 @@ class RobertaLMHead(nn.Module):
return x
-
class RobertaForSequenceClassification(BertPreTrainedModel):
"""
Roberta Model with a classifier head on top.
diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py
index 7fa42bfb1c..4ec53a65b0 100644
--- a/pytorch_transformers/tokenization_roberta.py
+++ b/pytorch_transformers/tokenization_roberta.py
@@ -21,18 +21,19 @@ import logging
import re
from io import open
import six
+import os
-from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
logger = logging.getLogger(__name__)
-VOCAB_FILES_NAMES = {
- 'vocab_file': 'dict.txt',
+DICT_FILES_NAMES = {
+ 'dict_file': 'dict.txt',
}
-PRETRAINED_VOCAB_FILES_MAP = {
- 'vocab_file':
+PRETRAINED_DICT_FILES_MAP = {
+ 'dict_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
@@ -178,89 +179,62 @@ class RobertaTokenizer(PreTrainedTokenizer):
RoBERTa tokenizer. Peculiarities:
- GPT-2 tokenizer with a different integer mapping on top.
"""
- vocab_files_names = VOCAB_FILES_NAMES
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ vocab_files_names = DICT_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
- def __init__(self, vocab_file,
- bos_token="", eos_token="", **kwargs):
- super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token, **kwargs)
+ def __init__(self, dict_file, bpe_tokenizer=None, bos_token="", eos_token="", sep_token="", cls_token="",
+ unk_token="", **kwargs):
+ super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
+ unk_token=unk_token, **kwargs)
- self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- self.dictionary = Dictionary.load(vocab_file)
+ self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
+ self.dictionary = Dictionary.load(dict_file)
+
+ @property
+ def vocab_size(self):
+ return len(self.dictionary.indices)
def _tokenize(self, text):
""" Use GPT-2 Tokenizer """
return self.gpt2_tokenizer._tokenize(text)
- def encode(self, text, *args):
- """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
- """
- bpe_sentence = [self.cls_token] + \
- self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text)) + \
- [self.sep_token]
-
- if len(args):
- for additional_sentence in args:
- bpe_sentence += [self.sep_token
- ] + \
- self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(additional_sentence)) + \
- [self.sep_token]
-
- return self.dictionary.encode_line(' '.join([str(token) for token in bpe_sentence]), append_eos=False)
-
- def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
- """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
- with options to remove special tokens and clean up tokenization spaces.
- Handles sentence pairs.
- """
- filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-
- if any(isinstance(element, list) for element in filtered_tokens):
- texts = []
- for element in filtered_tokens:
- text = self.convert_tokens_to_string(element)
- if clean_up_tokenization_spaces:
- text = clean_up_tokenization(text)
- texts.append(text)
- return texts
- else:
- text = self.convert_tokens_to_string(filtered_tokens)
- if clean_up_tokenization_spaces:
- text = clean_up_tokenization(text)
- return text
-
def _convert_token_to_id(self, token):
- return self.dictionary.index(token)
+ if self.dictionary.index(token) != 3:
+ return self.dictionary.index(token)
+ return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
def _convert_id_to_token(self, index):
symbol = self.dictionary[index]
try:
idx = int(symbol)
return self.gpt2_tokenizer._convert_id_to_token(idx)
- except:
+ except ValueError:
return symbol
def convert_tokens_to_string(self, tokens):
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
+ def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
+ cls = [self._convert_token_to_id(self.cls_token)]
+ tokens = super().convert_tokens_to_ids(tokens)
+ sep = [self._convert_token_to_id(self.sep_token)]
+ return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
+
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
- # Remove the first and last tokens which are cls and sep tokens
- ids = ids[1:-1]
- # If multi sentence, then split (multi sentence found by looking for two sequential sep tokens)
- ids = [list(map(int, example.split(' '))) for example in ' '.join([str(id) for id in ids]).split(' 2 2 ')]
+ return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
- if len(ids) == 1:
- tokens = self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), ids[0])))
- else:
- tokens = []
- for example in ids:
- tokens += [
- self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), example)))]
- return tokens
+ def save_vocabulary(self, save_directory):
+ """Save the tokenizer vocabulary and merge files to a directory."""
+ if not os.path.isdir(save_directory):
+ logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+ return
+ dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
- def convert_tokens_to_ids(self, tokens):
- tokens = " ".join(str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(tokens))
- bpe_sentence = ' ' + tokens + ' '
- return self.dictionary.encode_line(bpe_sentence, append_eos=False)
+ with open(dict_file, 'w', encoding='utf-8') as f:
+ for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
+ f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
+ vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
+
+ return vocab_files + (dict_file,)
diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py
index 2e75c83bfb..232ef1c35b 100644
--- a/pytorch_transformers/tokenization_utils.py
+++ b/pytorch_transformers/tokenization_utils.py
@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
"""
raise NotImplementedError
- def convert_tokens_to_ids(self, tokens):
+ def convert_tokens_to_ids(self, tokens, **kwargs):
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp. a sequence of ids), using the vocabulary.
"""
@@ -520,12 +520,29 @@ class PreTrainedTokenizer(object):
raise NotImplementedError
- def encode(self, text):
+ def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
"""
- return self.convert_tokens_to_ids(self.tokenize(text))
+
+ if len(text) == 1:
+ return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
+
+ if len(text) > 2:
+ logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
+ "initial two.")
+
+ first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
+ second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
+ sep = [self._convert_token_to_id(self.sep_token)]
+ cls = [self._convert_token_to_id(self.cls_token)]
+ n_sep_token = 2 if double_sep_token else 1
+
+ tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
+ tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
+
+ return tokens
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
@@ -560,7 +577,8 @@ class PreTrainedTokenizer(object):
"""
return ' '.join(self.convert_ids_to_tokens(tokens))
- def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+ def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False,
+ double_sep_token=False):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
@@ -568,9 +586,21 @@ class PreTrainedTokenizer(object):
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
- if clean_up_tokenization_spaces:
- text = self.clean_up_tokenization(text)
- return text
+
+ if self.sep_token is not None and self.sep_token in text:
+ text = text.replace(self.cls_token, self.sep_token)
+ split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
+ if clean_up_tokenization_spaces:
+ clean_text = [self.clean_up_tokenization(text) for text in split_text]
+ return clean_text
+ else:
+ return split_text
+ else:
+ if clean_up_tokenization_spaces:
+ clean_text = self.clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
@property
def special_tokens_map(self):
@@ -602,7 +632,7 @@ class PreTrainedTokenizer(object):
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
- all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
+ all_ids = list(self._convert_token_to_id(t) for t in all_toks)
return all_ids
@staticmethod