Use roberta model and update doc strings

This commit is contained in:
Matt Maybeno
2019-10-23 21:31:28 -07:00
committed by Julien Chaumond
parent 66085a1321
commit b92d68421d
2 changed files with 6 additions and 2 deletions

View File

@@ -478,12 +478,16 @@ class RobertaForTokenClassification(BertPreTrainedModel):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base') tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForTokenClassification.from_pretrained('roberta-base') model = RobertaForTokenClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels) outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2] loss, scores = outputs[:2]
""" """
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config): def __init__(self, config):
super(RobertaForTokenClassification, self).__init__(config) super(RobertaForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels

View File

@@ -396,7 +396,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base') tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForTokenClassification.from_pretrained('roberta-base') model = TFRobertaForTokenClassification.from_pretrained('roberta-base')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids) outputs = model(input_ids)
scores = outputs[0] scores = outputs[0]