Warning for RoBERTa sequences encoded without special tokens.
This commit is contained in:
@@ -165,6 +165,13 @@ class RobertaModel(BertModel):
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
|
||||
if input_ids[:, 0].sum().item() != 0:
|
||||
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
|
||||
"This model requires special tokens in order to work. "
|
||||
"Please specify add_special_tokens=True in your encoding.")
|
||||
return super(RobertaModel, self).forward(input_ids, token_type_ids, attention_mask, position_ids, head_mask)
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
|
||||
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
||||
|
||||
Reference in New Issue
Block a user