From 8308170156bdf41134fd0a8027f63f57f9e6a8d6 Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Thu, 15 Aug 2019 10:29:04 -0400
Subject: [PATCH] Warning for RoBERTa sequences encoded without special tokens.

---
 pytorch_transformers/modeling_roberta.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py
index ebf701ead6..adb04b4b3a 100644
--- a/pytorch_transformers/modeling_roberta.py
+++ b/pytorch_transformers/modeling_roberta.py
@@ -165,6 +165,26 @@ class RobertaModel(BertModel):
         self.embeddings = RobertaEmbeddings(config)
         self.apply(self.init_weights)
 
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
+        """Warn if any sequence does not start with token id 0, then run ``BertModel.forward``.
+
+        The check reads ``input_ids[:, 0]``, so ``input_ids`` must be a 2-D tensor. A non-zero
+        sum of that column triggers the warning (assuming token ids are non-negative).
+        NOTE(review): id 0 is presumably RoBERTa's ``<s>`` start token -- confirm against the vocab.
+        Returns whatever ``BertModel.forward`` returns for the same arguments.
+        """
+        if input_ids[:, 0].sum().item() != 0:
+            logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
+                           "This model requires special tokens in order to work. "
+                           "Please specify add_special_tokens=True in your encoding.")
+        # Forward by keyword so the arguments stay correctly bound even if the order of
+        # BertModel.forward's optional parameters differs from this signature.
+        return super(RobertaModel, self).forward(input_ids,
+                                                 token_type_ids=token_type_ids,
+                                                 attention_mask=attention_mask,
+                                                 position_ids=position_ids,
+                                                 head_mask=head_mask)
+
 
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
     ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)