From b92d68421dee75c3a078b26b78a05bd59007d855 Mon Sep 17 00:00:00 2001 From: Matt Maybeno Date: Wed, 23 Oct 2019 21:31:28 -0700 Subject: [PATCH] Use roberta model and update doc strings --- transformers/modeling_roberta.py | 6 +++++- transformers/modeling_tf_roberta.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_roberta.py b/transformers/modeling_roberta.py index 6b8d381579..9d16c87888 100644 --- a/transformers/modeling_roberta.py +++ b/transformers/modeling_roberta.py @@ -478,12 +478,16 @@ class RobertaForTokenClassification(BertPreTrainedModel): tokenizer = RobertaTokenizer.from_pretrained('roberta-base') model = RobertaForTokenClassification.from_pretrained('roberta-base') - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 outputs = model(input_ids, labels=labels) loss, scores = outputs[:2] """ + config_class = RobertaConfig + pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "roberta" + def __init__(self, config): super(RobertaForTokenClassification, self).__init__(config) self.num_labels = config.num_labels diff --git a/transformers/modeling_tf_roberta.py b/transformers/modeling_tf_roberta.py index 13a0522211..a239bc642b 100644 --- a/transformers/modeling_tf_roberta.py +++ b/transformers/modeling_tf_roberta.py @@ -396,7 +396,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel): tokenizer = RobertaTokenizer.from_pretrained('roberta-base') model = TFRobertaForTokenClassification.from_pretrained('roberta-base') - input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1 outputs = model(input_ids) scores = outputs[0]