Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -9,6 +9,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from transformers import (
|
||||
@@ -35,7 +36,11 @@ class Split(Enum):
|
||||
|
||||
|
||||
def get_tfds(
|
||||
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
|
||||
task_name: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_length: Optional[int] = None,
|
||||
mode: Split = Split.train,
|
||||
data_dir: str = None,
|
||||
):
|
||||
if task_name == "mnli-mm" and mode == Split.dev:
|
||||
tfds_name = "mnli_mismatched"
|
||||
@@ -50,9 +55,11 @@ def get_tfds(
|
||||
else:
|
||||
tfds_name = task_name
|
||||
|
||||
ds = tfds.load("glue/" + tfds_name, split=mode.value)
|
||||
ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
|
||||
ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||
ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
|
||||
|
||||
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||
return ds
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
|
||||
"""
|
||||
|
||||
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
||||
data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
@@ -171,13 +179,22 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
|
||||
get_tfds(
|
||||
task_name=data_args.task_name,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
data_dir=data_args.data_dir,
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
get_tfds(
|
||||
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
|
||||
task_name=data_args.task_name,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
mode=Split.dev,
|
||||
data_dir=data_args.data_dir,
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
|
||||
Reference in New Issue
Block a user