Add Roberta to run_ner.py
This commit is contained in:
committed by
Julien Chaumond
parent
b92d68421d
commit
4e5f88b74f
@@ -35,15 +35,17 @@ from utils_ner import convert_examples_to_features, get_labels, read_examples_fr
|
|||||||
|
|
||||||
from transformers import AdamW, WarmupLinearSchedule
|
from transformers import AdamW, WarmupLinearSchedule
|
||||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
||||||
|
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, )),
|
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
|
||||||
())
|
())
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||||
|
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user