Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre's comments

* Trigger CI

* Remove unused import
This commit is contained in:
Julien Plu
2020-07-29 20:32:01 +02:00
committed by GitHub
parent 3f94170a10
commit 54f9fbeff8
9 changed files with 247 additions and 214 deletions

View File

@@ -9,6 +9,7 @@ from enum import Enum
from typing import Dict, Optional
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
def get_tfds(
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
task_name: str,
tokenizer: PreTrainedTokenizer,
max_seq_length: Optional[int] = None,
mode: Split = Split.train,
data_dir: str = None,
):
if task_name == "mnli-mm" and mode == Split.dev:
tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
else:
tfds_name = task_name
ds = tfds.load("glue/" + tfds_name, split=mode.value)
ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
return ds
logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
"""
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
max_seq_length: int = field(
default=128,
metadata={
@@ -171,13 +179,22 @@ def main():
# Get datasets
train_dataset = (
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
get_tfds(
task_name=data_args.task_name,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
data_dir=data_args.data_dir,
)
if training_args.do_train
else None
)
eval_dataset = (
get_tfds(
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
task_name=data_args.task_name,
tokenizer=tokenizer,
max_seq_length=data_args.max_seq_length,
mode=Split.dev,
data_dir=data_args.data_dir,
)
if training_args.do_eval
else None