Warning for RoBERTa sequences encoded without special tokens.

This commit is contained in:
LysandreJik
2019-08-15 10:29:04 -04:00
parent 572dcfd1db
commit 8308170156

View File

@@ -165,6 +165,13 @@ class RobertaModel(BertModel):
self.embeddings = RobertaEmbeddings(config) self.embeddings = RobertaEmbeddings(config)
self.apply(self.init_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
if input_ids[:, 0].sum().item() != 0:
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding.")
return super(RobertaModel, self).forward(input_ids, token_type_ids, attention_mask, position_ids, head_mask)
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING) ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)