Fix the TF Trainer gradient accumulation and the TF NER example (#6713)
* Align TF NER example over the PT one
* Fix Dataset call
* Fix gradient accumulation training
* Apply style
* Address Sylvain's comments
* Address Sylvain's comments
* Apply style
This commit is contained in:
@@ -18,6 +18,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from importlib import import_module
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -32,7 +33,7 @@ from transformers import (
|
|||||||
TFTrainer,
|
TFTrainer,
|
||||||
TFTrainingArguments,
|
TFTrainingArguments,
|
||||||
)
|
)
|
||||||
from utils_ner import Split, TFNerDataset, get_labels
|
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -50,6 +51,9 @@ class ModelArguments:
|
|||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
|
task_type: Optional[str] = field(
|
||||||
|
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
|
||||||
|
)
|
||||||
tokenizer_name: Optional[str] = field(
|
tokenizer_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
@@ -102,6 +106,17 @@ def main():
|
|||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
module = import_module("tasks")
|
||||||
|
|
||||||
|
try:
|
||||||
|
token_classification_task_clazz = getattr(module, model_args.task_type)
|
||||||
|
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
|
||||||
|
except AttributeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
|
||||||
|
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
|
||||||
|
)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
@@ -117,7 +132,7 @@ def main():
|
|||||||
logger.info("Training/evaluation parameters %s", training_args)
|
logger.info("Training/evaluation parameters %s", training_args)
|
||||||
|
|
||||||
# Prepare Token Classification task
|
# Prepare Token Classification task
|
||||||
labels = get_labels(data_args.labels)
|
labels = token_classification_task.get_labels(data_args.labels)
|
||||||
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
|
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
|
||||||
num_labels = len(labels)
|
num_labels = len(labels)
|
||||||
|
|
||||||
@@ -150,7 +165,8 @@ def main():
|
|||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = (
|
train_dataset = (
|
||||||
TFNerDataset(
|
TFTokenClassificationDataset(
|
||||||
|
token_classification_task=token_classification_task,
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
@@ -163,7 +179,8 @@ def main():
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
TFNerDataset(
|
TFTokenClassificationDataset(
|
||||||
|
token_classification_task=token_classification_task,
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
@@ -233,7 +250,8 @@ def main():
|
|||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
test_dataset = TFNerDataset(
|
test_dataset = TFTokenClassificationDataset(
|
||||||
|
token_classification_task=token_classification_task,
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ if is_torch_available():
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
class TFNerDataset:
|
class TFTokenClassificationDataset:
|
||||||
"""
|
"""
|
||||||
This will be superseded by a framework-agnostic approach
|
This will be superseded by a framework-agnostic approach
|
||||||
soon.
|
soon.
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class TFTokenClassificationLoss:
|
|||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100
|
||||||
# are taken into account as loss
|
# are taken into account as loss
|
||||||
if tf.math.reduce_any(labels == -1).numpy() is True:
|
if tf.math.reduce_any(labels == -1):
|
||||||
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -620,13 +620,22 @@ class TFTrainer:
|
|||||||
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
||||||
else:
|
else:
|
||||||
for _ in tf.range(self.args.gradient_accumulation_steps):
|
for _ in tf.range(self.args.gradient_accumulation_steps):
|
||||||
reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
|
reduced_features = {
|
||||||
reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
|
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
|
||||||
|
}
|
||||||
|
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
|
||||||
|
|
||||||
self.training_step(reduced_features, reduced_labels)
|
self.training_step(reduced_features, reduced_labels)
|
||||||
|
|
||||||
features = tf.concat(
|
features = {
|
||||||
[features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
|
k: tf.concat(
|
||||||
|
[ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]], axis=0,
|
||||||
|
)
|
||||||
|
for k, ft in features.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
labels = tf.concat(
|
||||||
|
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
|
||||||
)
|
)
|
||||||
|
|
||||||
gradients = self.gradient_accumulator.gradients
|
gradients = self.gradient_accumulator.gradients
|
||||||
|
|||||||
Reference in New Issue
Block a user