RoBERTa doesn't print a warning when no special tokens are passed.
This commit is contained in:
@@ -65,22 +65,6 @@ class TFRobertaMainLayer(TFBertMainLayer):
|
||||
super(TFRobertaMainLayer, self).__init__(config, **kwargs)
|
||||
self.embeddings = TFRobertaEmbeddings(config, name='embeddings')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
# Check that input_ids starts with control token
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if tf.not_equal(tf.reduce_sum(input_ids[:, 0]), 0):
|
||||
tf.print("A sequence with no special tokens has been passed to the RoBERTa model. "
|
||||
"This model requires special tokens in order to work. "
|
||||
"Please specify add_special_tokens=True in your encoding.")
|
||||
|
||||
return super(TFRobertaMainLayer, self).call(inputs, **kwargs)
|
||||
|
||||
|
||||
class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
|
||||
Reference in New Issue
Block a user