Fix the TF Trainer gradient accumulation and the TF NER example (#6713)

* Align TF NER example over the PT one

* Fix Dataset call

* Fix gradient accumulation training

* Apply style

* Address Sylvain's comments

* Address Sylvain's comments

* Apply style
This commit is contained in:
Julien Plu
2020-08-27 14:45:34 +02:00
committed by GitHub
parent 41aa2b4ef1
commit 6f289dc97a
4 changed files with 38 additions and 11 deletions

View File

@@ -18,6 +18,7 @@
import logging
import os
from dataclasses import dataclass, field
from importlib import import_module
from typing import Dict, List, Optional, Tuple
import numpy as np
@@ -32,7 +33,7 @@ from transformers import (
TFTrainer,
TFTrainingArguments,
)
from utils_ner import Split, TFNerDataset, get_labels
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
logger = logging.getLogger(__name__)
@@ -50,6 +51,9 @@ class ModelArguments:
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
task_type: Optional[str] = field(
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
@@ -102,6 +106,17 @@ def main():
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
module = import_module("tasks")
try:
token_classification_task_clazz = getattr(module, model_args.task_type)
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
except AttributeError:
raise ValueError(
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -117,7 +132,7 @@ def main():
logger.info("Training/evaluation parameters %s", training_args)
# Prepare Token Classification task
labels = get_labels(data_args.labels)
labels = token_classification_task.get_labels(data_args.labels)
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)
@@ -150,7 +165,8 @@ def main():
# Get datasets
train_dataset = (
TFNerDataset(
TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -163,7 +179,8 @@ def main():
else None
)
eval_dataset = (
TFNerDataset(
TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -233,7 +250,8 @@ def main():
# Predict
if training_args.do_predict:
test_dataset = TFNerDataset(
test_dataset = TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,