[FLAX] glue training example refactor (#13815)
* refactor run_flax_glue.py * updated readme * rm unused import and args typo fix * refactor * make consistent arg name across task * has_tensorboard check * argparse -> argument dataclasses * refactor according to review * fix
This commit is contained in:
@@ -85,10 +85,10 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--per_device_train_batch_size=2
|
--per_device_train_batch_size=2
|
||||||
--per_device_eval_batch_size=1
|
--per_device_eval_batch_size=1
|
||||||
--learning_rate=1e-4
|
--learning_rate=1e-4
|
||||||
--max_train_steps=10
|
--eval_steps=2
|
||||||
--num_warmup_steps=2
|
--warmup_steps=2
|
||||||
--seed=42
|
--seed=42
|
||||||
--max_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
|
|||||||
@@ -33,15 +33,16 @@ export TASK_NAME=mrpc
|
|||||||
python run_flax_glue.py \
|
python run_flax_glue.py \
|
||||||
--model_name_or_path bert-base-cased \
|
--model_name_or_path bert-base-cased \
|
||||||
--task_name ${TASK_NAME} \
|
--task_name ${TASK_NAME} \
|
||||||
--max_length 128 \
|
--max_seq_length 128 \
|
||||||
--learning_rate 2e-5 \
|
--learning_rate 2e-5 \
|
||||||
--num_train_epochs 3 \
|
--num_train_epochs 3 \
|
||||||
--per_device_train_batch_size 4 \
|
--per_device_train_batch_size 4 \
|
||||||
|
--eval_steps 100 \
|
||||||
--output_dir ./$TASK_NAME/ \
|
--output_dir ./$TASK_NAME/ \
|
||||||
--push_to_hub
|
--push_to_hub
|
||||||
```
|
```
|
||||||
|
|
||||||
where task name can be one of cola, mnli, mnli-mm, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
|
||||||
|
|
||||||
Using the command above, the script will train for 3 epochs and run eval after each epoch.
|
Using the command above, the script will train for 3 epochs and run eval every 100 steps (`--eval_steps 100`) as well as at the end of each epoch.
|
||||||
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
|
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
|
||||||
|
|||||||
@@ -14,18 +14,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
|
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
|
||||||
import argparse
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import numpy as np
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -40,13 +43,18 @@ from transformers import (
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
FlaxAutoModelForSequenceClassification,
|
FlaxAutoModelForSequenceClassification,
|
||||||
|
HfArgumentParser,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
|
TrainingArguments,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import get_full_repo_name
|
from transformers.file_utils import get_full_repo_name
|
||||||
|
from transformers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||||
|
check_min_version("4.16.0.dev0")
|
||||||
|
|
||||||
Array = Any
|
Array = Any
|
||||||
Dataset = datasets.arrow_dataset.Dataset
|
Dataset = datasets.arrow_dataset.Dataset
|
||||||
@@ -66,101 +74,118 @@ task_to_keys = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
@dataclass
|
||||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
class ModelArguments:
|
||||||
parser.add_argument(
|
"""
|
||||||
"--task_name",
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
type=str,
|
"""
|
||||||
default=None,
|
|
||||||
help="The name of the glue task to train on.",
|
|
||||||
choices=list(task_to_keys.keys()),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_length",
|
|
||||||
type=int,
|
|
||||||
default=128,
|
|
||||||
help=(
|
|
||||||
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
|
||||||
" sequences shorter will be padded."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
||||||
required=True,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_slow_tokenizer",
|
|
||||||
action="store_true",
|
|
||||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--per_device_train_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="Batch size (per device) for the training dataloader.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--per_device_eval_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="Batch size (per device) for the evaluation dataloader.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--learning_rate",
|
|
||||||
type=float,
|
|
||||||
default=5e-5,
|
|
||||||
help="Initial learning rate (after the potential warmup period) to use.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
|
||||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_train_steps",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
|
||||||
)
|
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
|
||||||
parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--push_to_hub",
|
|
||||||
action="store_true",
|
|
||||||
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
|
||||||
)
|
|
||||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Sanity checks
|
model_name_or_path: str = field(
|
||||||
if args.task_name is None and args.train_file is None and args.validation_file is None:
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
raise ValueError("Need either a task name or a training/validation file.")
|
)
|
||||||
|
config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
tokenizer_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
use_slow_tokenizer: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
|
||||||
|
)
|
||||||
|
cache_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||||
|
)
|
||||||
|
model_revision: str = field(
|
||||||
|
default="main",
|
||||||
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
|
)
|
||||||
|
use_auth_token: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
|
"with private models)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataTrainingArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
|
||||||
|
)
|
||||||
|
dataset_config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||||
|
)
|
||||||
|
train_file: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
|
||||||
|
)
|
||||||
|
validation_file: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
|
||||||
|
)
|
||||||
|
test_file: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
|
||||||
|
)
|
||||||
|
text_column_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
|
||||||
|
)
|
||||||
|
label_column_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
|
||||||
|
)
|
||||||
|
overwrite_cache: bool = field(
|
||||||
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
|
)
|
||||||
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||||
|
)
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The maximum total input sequence length after tokenization. If set, sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_train_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_eval_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_predict_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.task_name is None and self.train_file is None and self.validation_file is None:
|
||||||
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
else:
|
else:
|
||||||
if args.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = args.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||||
if args.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = args.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
self.task_name = self.task_name.lower() if type(self.task_name) == str else self.task_name
|
||||||
if args.push_to_hub:
|
|
||||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
|
||||||
|
|
||||||
if args.output_dir is not None:
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def create_train_state(
|
def create_train_state(
|
||||||
@@ -249,7 +274,7 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
|||||||
|
|
||||||
for perm in perms:
|
for perm in perms:
|
||||||
batch = dataset[perm]
|
batch = dataset[perm]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: np.array(v) for k, v in batch.items()}
|
||||||
batch = shard(batch)
|
batch = shard(batch)
|
||||||
|
|
||||||
yield batch
|
yield batch
|
||||||
@@ -259,14 +284,20 @@ def glue_eval_data_collator(dataset: Dataset, batch_size: int):
|
|||||||
"""Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
|
"""Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
|
||||||
for i in range(len(dataset) // batch_size):
|
for i in range(len(dataset) // batch_size):
|
||||||
batch = dataset[i * batch_size : (i + 1) * batch_size]
|
batch = dataset[i * batch_size : (i + 1) * batch_size]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: np.array(v) for k, v in batch.items()}
|
||||||
batch = shard(batch)
|
batch = shard(batch)
|
||||||
|
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
# Make one log on every process with the configuration for debugging.
|
# Make one log on every process with the configuration for debugging.
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -284,12 +315,14 @@ def main():
|
|||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
# Handle the repository creation
|
# Handle the repository creation
|
||||||
if args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
if args.hub_model_id is None:
|
if training_args.hub_model_id is None:
|
||||||
repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
|
repo_name = get_full_repo_name(
|
||||||
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
repo_name = args.hub_model_id
|
repo_name = training_args.hub_model_id
|
||||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||||
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
||||||
@@ -303,24 +336,24 @@ def main():
|
|||||||
|
|
||||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||||
# download the dataset.
|
# download the dataset.
|
||||||
if args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
raw_datasets = load_dataset("glue", args.task_name)
|
raw_datasets = load_dataset("glue", data_args.task_name)
|
||||||
else:
|
else:
|
||||||
# Loading the dataset from local csv or json file.
|
# Loading the dataset from local csv or json file.
|
||||||
data_files = {}
|
data_files = {}
|
||||||
if args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
if args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
data_files["validation"] = args.validation_file
|
data_files["validation"] = data_args.validation_file
|
||||||
extension = (args.train_file if args.train_file is not None else args.valid_file).split(".")[-1]
|
extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1]
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||||
# See more about loading any type of standard or custom dataset at
|
# See more about loading any type of standard or custom dataset at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
if args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
is_regression = args.task_name == "stsb"
|
is_regression = data_args.task_name == "stsb"
|
||||||
if not is_regression:
|
if not is_regression:
|
||||||
label_list = raw_datasets["train"].features["label"].names
|
label_list = raw_datasets["train"].features["label"].names
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
@@ -339,13 +372,17 @@ def main():
|
|||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
|
config = AutoConfig.from_pretrained(
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name
|
||||||
model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer
|
||||||
|
)
|
||||||
|
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)
|
||||||
|
|
||||||
# Preprocessing the datasets
|
# Preprocessing the datasets
|
||||||
if args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
sentence1_key, sentence2_key = task_to_keys[args.task_name]
|
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
|
||||||
else:
|
else:
|
||||||
# Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
|
# Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
|
||||||
non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
|
non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
|
||||||
@@ -361,7 +398,7 @@ def main():
|
|||||||
label_to_id = None
|
label_to_id = None
|
||||||
if (
|
if (
|
||||||
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
|
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
|
||||||
and args.task_name is not None
|
and data_args.task_name is not None
|
||||||
and not is_regression
|
and not is_regression
|
||||||
):
|
):
|
||||||
# Some have all caps in their config, some don't.
|
# Some have all caps in their config, some don't.
|
||||||
@@ -378,7 +415,7 @@ def main():
|
|||||||
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
||||||
"\nIgnoring the model labels as a result.",
|
"\nIgnoring the model labels as a result.",
|
||||||
)
|
)
|
||||||
elif args.task_name is None:
|
elif data_args.task_name is None:
|
||||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
@@ -386,7 +423,7 @@ def main():
|
|||||||
texts = (
|
texts = (
|
||||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||||
)
|
)
|
||||||
result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)
|
result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)
|
||||||
|
|
||||||
if "label" in examples:
|
if "label" in examples:
|
||||||
if label_to_id is not None:
|
if label_to_id is not None:
|
||||||
@@ -402,7 +439,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = processed_datasets["train"]
|
train_dataset = processed_datasets["train"]
|
||||||
eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
|
eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||||
|
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
@@ -414,8 +451,8 @@ def main():
|
|||||||
try:
|
try:
|
||||||
from flax.metrics.tensorboard import SummaryWriter
|
from flax.metrics.tensorboard import SummaryWriter
|
||||||
|
|
||||||
summary_writer = SummaryWriter(args.output_dir)
|
summary_writer = SummaryWriter(training_args.output_dir)
|
||||||
summary_writer.hparams(vars(args))
|
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
|
||||||
except ImportError as ie:
|
except ImportError as ie:
|
||||||
has_tensorboard = False
|
has_tensorboard = False
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -427,7 +464,7 @@ def main():
|
|||||||
"Please run pip install tensorboard to enable."
|
"Please run pip install tensorboard to enable."
|
||||||
)
|
)
|
||||||
|
|
||||||
def write_metric(train_metrics, eval_metrics, train_time, step):
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||||
summary_writer.scalar("train_time", train_time, step)
|
summary_writer.scalar("train_time", train_time, step)
|
||||||
|
|
||||||
train_metrics = get_metrics(train_metrics)
|
train_metrics = get_metrics(train_metrics)
|
||||||
@@ -436,22 +473,27 @@ def main():
|
|||||||
for i, val in enumerate(vals):
|
for i, val in enumerate(vals):
|
||||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||||
|
|
||||||
|
def write_eval_metric(summary_writer, eval_metrics, step):
|
||||||
for metric_name, value in eval_metrics.items():
|
for metric_name, value in eval_metrics.items():
|
||||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
num_epochs = int(args.num_train_epochs)
|
num_epochs = int(training_args.num_train_epochs)
|
||||||
rng = jax.random.PRNGKey(args.seed)
|
rng = jax.random.PRNGKey(training_args.seed)
|
||||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||||
|
|
||||||
train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
|
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
|
||||||
eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
|
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
|
||||||
|
|
||||||
learning_rate_fn = create_learning_rate_fn(
|
learning_rate_fn = create_learning_rate_fn(
|
||||||
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
|
len(train_dataset),
|
||||||
|
train_batch_size,
|
||||||
|
training_args.num_train_epochs,
|
||||||
|
training_args.warmup_steps,
|
||||||
|
training_args.learning_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
state = create_train_state(
|
state = create_train_state(
|
||||||
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
|
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
|
||||||
)
|
)
|
||||||
|
|
||||||
# define step functions
|
# define step functions
|
||||||
@@ -482,8 +524,8 @@ def main():
|
|||||||
|
|
||||||
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
||||||
|
|
||||||
if args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
metric = load_metric("glue", args.task_name)
|
metric = load_metric("glue", data_args.task_name)
|
||||||
else:
|
else:
|
||||||
metric = load_metric("accuracy")
|
metric = load_metric("accuracy")
|
||||||
|
|
||||||
@@ -493,25 +535,56 @@ def main():
|
|||||||
# make sure weights are replicated on each device
|
# make sure weights are replicated on each device
|
||||||
state = replicate(state)
|
state = replicate(state)
|
||||||
|
|
||||||
for epoch in range(1, num_epochs + 1):
|
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||||
logger.info(f"Epoch {epoch}")
|
total_steps = steps_per_epoch * num_epochs
|
||||||
logger.info(" Training...")
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
|
||||||
|
for epoch in epochs:
|
||||||
|
|
||||||
train_start = time.time()
|
train_start = time.time()
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
|
|
||||||
|
# Create sampling rng
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
|
train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
|
||||||
state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
|
for step, batch in enumerate(
|
||||||
train_metrics.append(metrics)
|
tqdm(
|
||||||
|
train_loader,
|
||||||
|
total=steps_per_epoch,
|
||||||
|
desc="Training...",
|
||||||
|
position=1,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
|
||||||
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
|
cur_step = (epoch * steps_per_epoch) + (step + 1)
|
||||||
|
|
||||||
|
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||||
|
# Save metrics
|
||||||
|
train_metric = unreplicate(train_metric)
|
||||||
train_time += time.time() - train_start
|
train_time += time.time() - train_start
|
||||||
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
logger.info(" Evaluating...")
|
epochs.write(
|
||||||
|
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
train_metrics = []
|
||||||
|
|
||||||
|
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
|
||||||
|
|
||||||
|
eval_metrics = {}
|
||||||
# evaluate
|
# evaluate
|
||||||
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
|
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
|
||||||
|
for batch in tqdm(
|
||||||
|
eval_loader,
|
||||||
|
total=len(eval_dataset) // eval_batch_size,
|
||||||
|
desc="Evaluating ...",
|
||||||
|
position=2,
|
||||||
|
):
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
predictions = p_eval_step(state, batch)
|
predictions = p_eval_step(state, batch)
|
||||||
metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
|
metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
|
||||||
@@ -523,33 +596,33 @@ def main():
|
|||||||
if num_leftover_samples > 0 and jax.process_index() == 0:
|
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||||
# take leftover samples
|
# take leftover samples
|
||||||
batch = eval_dataset[-num_leftover_samples:]
|
batch = eval_dataset[-num_leftover_samples:]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: np.array(v) for k, v in batch.items()}
|
||||||
|
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
predictions = eval_step(unreplicate(state), batch)
|
predictions = eval_step(unreplicate(state), batch)
|
||||||
metric.add_batch(predictions=predictions, references=labels)
|
metric.add_batch(predictions=predictions, references=labels)
|
||||||
|
|
||||||
eval_metric = metric.compute()
|
eval_metric = metric.compute()
|
||||||
logger.info(f" Done! Eval metrics: {eval_metric}")
|
|
||||||
|
|
||||||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
|
||||||
|
|
||||||
# Save metrics
|
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
write_metric(train_metrics, eval_metric, train_time, cur_step)
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||||
|
|
||||||
|
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint every `save_steps` steps (and at the final training step) and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
params = jax.device_get(unreplicate(state.params))
|
||||||
model.save_pretrained(args.output_dir, params=params)
|
model.save_pretrained(training_args.output_dir, params=params)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(training_args.output_dir)
|
||||||
if args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
|
||||||
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
|
||||||
|
|
||||||
# save the eval metrics in json
|
# save the eval metrics in json
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
|
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
|
||||||
path = os.path.join(args.output_dir, "eval_results.json")
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(eval_metric, f, indent=4, sort_keys=True)
|
json.dump(eval_metric, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user