RoBERTa token classification

[WIP] copy paste bert token classification for roberta
This commit is contained in:
Matt Maybeno
2019-10-23 21:05:13 -07:00
committed by Julien Chaumond
parent 5b6cafb11b
commit 66085a1321
5 changed files with 158 additions and 1 deletions

View File

@@ -24,7 +24,8 @@ from transformers import is_torch_available
if is_torch_available():
import torch
from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM,
RobertaForSequenceClassification, RobertaForTokenClassification)
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
@@ -156,6 +157,22 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_roberta_for_token_classification(self, config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = RobertaForTokenClassification(config=config)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,