Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -21,6 +21,8 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@@ -68,6 +70,7 @@ class DataTrainingArguments:
|
||||
data_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
||||
)
|
||||
use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
@@ -170,7 +173,7 @@ def main():
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
if not data_args.data_dir:
|
||||
if data_args.use_tfds:
|
||||
if data_args.version_2_with_negative:
|
||||
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
|
||||
|
||||
@@ -179,7 +182,7 @@ def main():
|
||||
except ImportError:
|
||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
|
||||
tfds_examples = tfds.load("squad")
|
||||
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
|
||||
train_examples = (
|
||||
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
|
||||
if training_args.do_train
|
||||
@@ -209,6 +212,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
|
||||
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
|
||||
|
||||
eval_dataset = (
|
||||
squad_convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
@@ -223,6 +228,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user