[Flax examples] remove dependancy on pytorch training args (#14636)
* use custom training arguments * update tests
This commit is contained in:
@@ -24,7 +24,8 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
@@ -49,7 +50,6 @@ from transformers import (
|
||||
AutoConfig,
|
||||
FlaxAutoModelForImageClassification,
|
||||
HfArgumentParser,
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
)
|
||||
@@ -63,6 +63,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
output_dir: str = field(
|
||||
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||
)
|
||||
overwrite_output_dir: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Overwrite the content of the output directory. "
|
||||
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||
)
|
||||
},
|
||||
)
|
||||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||
per_device_train_batch_size: int = field(
|
||||
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||
)
|
||||
per_device_eval_batch_size: int = field(
|
||||
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||
)
|
||||
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||
push_to_hub: bool = field(
|
||||
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||
)
|
||||
hub_model_id: str = field(
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.output_dir is not None:
|
||||
self.output_dir = os.path.expanduser(self.output_dir)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||
the token values by removing their value.
|
||||
"""
|
||||
d = asdict(self)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, Enum):
|
||||
d[k] = v.value
|
||||
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||
d[k] = [x.value for x in v]
|
||||
if k.endswith("_token"):
|
||||
d[k] = f"<{k.upper()}>"
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user