Rework TF trainer (#6038)

* Fully rework training/prediction loops

* fix method name

* Fix variable name

* Fix property name

* Fix scope

* Fix method name

* Fix tuple index

* Fix tuple index

* Fix indentation

* Fix variable name

* fix eval before log

* Add drop remainder for test dataset

* Fix step number + fix logging datetime

* fix eval loss value

* use global step instead of step + fix logging at step 0

* Fix logging datetime

* Fix global_step usage

* Fix breaking loop + logging datetime

* Fix step in prediction loop

* Fix step breaking

* Fix train/test loops

* Force TF at least 2.2 for the trainer

* Use assert_cardinality to facilitate the dataset size computation

* Log steps per epoch

* Make tfds compliant with TPU

* Make tfds compliant with TPU

* Use TF dataset enumerate instead of the Python one

* revert previous commit

* Fix data_dir

* Apply style

* rebase on master

* Address Sylvain's comments

* Address Sylvain's and Lysandre comments

* Trigger CI

* Remove unused import
This commit is contained in:
Julien Plu
2020-07-29 20:32:01 +02:00
committed by GitHub
parent 3f94170a10
commit 54f9fbeff8
9 changed files with 247 additions and 214 deletions

View File

@@ -21,6 +21,8 @@ import os
from dataclasses import dataclass, field
from typing import Optional
import tensorflow as tf
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
data_dir: Optional[str] = field(
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
)
use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
max_seq_length: int = field(
default=128,
metadata={
@@ -170,7 +173,7 @@ def main():
)
# Get datasets
if not data_args.data_dir:
if data_args.use_tfds:
if data_args.version_2_with_negative:
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
@@ -179,7 +182,7 @@ def main():
except ImportError:
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
tfds_examples = tfds.load("squad")
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
train_examples = (
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
if training_args.do_train
@@ -209,6 +212,8 @@ def main():
else None
)
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
eval_dataset = (
squad_convert_examples_to_features(
examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
else None
)
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
# Initialize our Trainer
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)