Merge pull request #1792 from stefan-it/distilbert-for-token-classification
DistilBERT for token classification
This commit is contained in:
@@ -554,6 +554,16 @@ On the test dataset the following results could be achieved:
|
||||
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
|
||||
```
|
||||
|
||||
### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
|
||||
|
||||
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
|
||||
|
||||
| Model | F-Score Dev | F-Score Test
|
||||
| --------------------------------- | ------- | --------
|
||||
| `bert-large-cased` | 95.59 | 91.70
|
||||
| `roberta-large` | 95.96 | 91.87
|
||||
| `distilbert-base-uncased` | 94.34 | 90.32
|
||||
|
||||
## Abstractive summarization
|
||||
|
||||
Based on the script
|
||||
|
||||
@@ -36,16 +36,18 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
||||
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
||||
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
|
||||
())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
|
||||
}
|
||||
|
||||
|
||||
@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
|
||||
# XLM and RoBERTa don"t use segment_ids
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
@@ -210,9 +213,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
|
||||
# XLM and RoBERTa don"t use segment_ids
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -524,3 +527,4 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user