Fix #3305: run_ner only possible on ModelForTokenClassification models
This commit is contained in:
@@ -31,7 +31,6 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -39,7 +38,7 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
from transformers.modeling_auto import MODEL_MAPPING
|
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
|
|
||||||
|
|
||||||
@@ -51,8 +50,9 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||||
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
||||||
|
|
||||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
|
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
|
|||||||
Reference in New Issue
Block a user