Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Examples
|
||||
|
||||
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
|
||||
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
|
||||
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
|
||||
|
||||
Here is the list of all our examples:
|
||||
- **grouped by task** (all official examples work for multiple models)
|
||||
|
||||
@@ -204,6 +204,8 @@ if is_tf_available():
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -21,6 +21,8 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@@ -68,6 +70,7 @@ class DataTrainingArguments:
|
||||
data_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
||||
)
|
||||
use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
@@ -170,7 +173,7 @@ def main():
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
if not data_args.data_dir:
|
||||
if data_args.use_tfds:
|
||||
if data_args.version_2_with_negative:
|
||||
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
|
||||
|
||||
@@ -179,7 +182,7 @@ def main():
|
||||
except ImportError:
|
||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
|
||||
tfds_examples = tfds.load("squad")
|
||||
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
|
||||
train_examples = (
|
||||
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
|
||||
if training_args.do_train
|
||||
@@ -209,6 +212,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
|
||||
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
|
||||
|
||||
eval_dataset = (
|
||||
squad_convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
@@ -223,6 +228,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from transformers import (
|
||||
@@ -35,7 +36,11 @@ class Split(Enum):
|
||||
|
||||
|
||||
def get_tfds(
|
||||
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
|
||||
task_name: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_length: Optional[int] = None,
|
||||
mode: Split = Split.train,
|
||||
data_dir: str = None,
|
||||
):
|
||||
if task_name == "mnli-mm" and mode == Split.dev:
|
||||
tfds_name = "mnli_mismatched"
|
||||
@@ -50,9 +55,11 @@ def get_tfds(
|
||||
else:
|
||||
tfds_name = task_name
|
||||
|
||||
ds = tfds.load("glue/" + tfds_name, split=mode.value)
|
||||
ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
|
||||
ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||
ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
|
||||
|
||||
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||
return ds
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
|
||||
"""
|
||||
|
||||
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
||||
data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
@@ -171,13 +179,22 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
|
||||
get_tfds(
|
||||
task_name=data_args.task_name,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
data_dir=data_args.data_dir,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
get_tfds(
|
||||
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
|
||||
task_name=data_args.task_name,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
mode=Split.dev,
|
||||
data_dir=data_args.data_dir,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -185,11 +184,6 @@ def main():
|
||||
|
||||
for i in range(batch_size):
|
||||
for j in range(seq_len):
|
||||
if label_ids[i, j] == -1:
|
||||
label_ids[i, j] = -100
|
||||
warnings.warn(
|
||||
"Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
|
||||
)
|
||||
if label_ids[i, j] != -100:
|
||||
out_label_list[i].append(label_map[label_ids[i][j]])
|
||||
preds_list[i].append(label_map[preds[i][j]])
|
||||
|
||||
@@ -146,7 +146,7 @@ if is_tf_available():
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
pad_token_label_id: int = -1
|
||||
pad_token_label_id: int = -100
|
||||
# Use cross entropy ignore_index as padding label id so that only
|
||||
# real label ids contribute to the loss later.
|
||||
|
||||
@@ -221,6 +221,8 @@ if is_tf_available():
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user