[Flax examples] remove dependency on pytorch training args (#14636)
* use custom training arguments
* update tests
This commit is contained in:
@@ -27,7 +27,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
@@ -53,7 +54,6 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
FlaxAutoModelForCausalLM,
|
FlaxAutoModelForCausalLM,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
@@ -67,6 +67,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
|||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||||
@@ -54,7 +55,6 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
TensorType,
|
TensorType,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
@@ -65,6 +65,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
|||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,13 +19,15 @@ Pretraining the library models for T5-like span-masked language modeling on a te
|
|||||||
Here is the full list of checkpoints on the hub that can be pretrained by this script:
|
Here is the full list of checkpoints on the hub that can be pretrained by this script:
|
||||||
https://huggingface.co/models?filter=t5
|
https://huggingface.co/models?filter=t5
|
||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
|
||||||
|
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||||
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
@@ -51,7 +53,6 @@ from transformers import (
|
|||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
T5Config,
|
T5Config,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
@@ -63,6 +64,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
|||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
@@ -50,7 +51,6 @@ from transformers import (
|
|||||||
FlaxAutoModelForQuestionAnswering,
|
FlaxAutoModelForQuestionAnswering,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import get_full_repo_name
|
from transformers.file_utils import get_full_repo_name
|
||||||
@@ -69,6 +69,69 @@ PRNGKey = Any
|
|||||||
|
|
||||||
|
|
||||||
# region Arguments
|
# region Arguments
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
@@ -51,7 +52,6 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
FlaxAutoModelForSeq2SeqLM,
|
FlaxAutoModelForSeq2SeqLM,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
||||||
@@ -74,6 +74,72 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
|
|||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
label_smoothing_factor: float = field(
|
||||||
|
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
||||||
|
)
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--test_file tests/fixtures/tests_samples/xsum/sample.json
|
--test_file tests/fixtures/tests_samples/xsum/sample.json
|
||||||
--output_dir {tmp_dir}
|
--output_dir {tmp_dir}
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--max_steps=50
|
--num_train_epochs=3
|
||||||
--warmup_steps=8
|
--warmup_steps=8
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
@@ -257,7 +257,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
--output_dir {tmp_dir}
|
--output_dir {tmp_dir}
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--max_steps=10
|
--num_train_epochs=3
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
@@ -44,7 +45,6 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
FlaxAutoModelForTokenClassification,
|
FlaxAutoModelForTokenClassification,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import get_full_repo_name
|
from transformers.file_utils import get_full_repo_name
|
||||||
@@ -63,6 +63,68 @@ Dataset = datasets.arrow_dataset.Dataset
|
|||||||
PRNGKey = Any
|
PRNGKey = Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
@@ -49,7 +50,6 @@ from transformers import (
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
FlaxAutoModelForImageClassification,
|
FlaxAutoModelForImageClassification,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
TrainingArguments,
|
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
@@ -63,6 +63,68 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
|
|||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingArguments:
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
|
||||||
|
)
|
||||||
|
overwrite_output_dir: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Overwrite the content of the output directory. "
|
||||||
|
"Use this to continue training if output_dir points to a checkpoint directory."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||||
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
|
||||||
|
per_device_train_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||||
|
)
|
||||||
|
per_device_eval_batch_size: int = field(
|
||||||
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||||
|
)
|
||||||
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
||||||
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
||||||
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
||||||
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
||||||
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
|
||||||
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
|
||||||
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||||
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||||
|
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
|
||||||
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
|
push_to_hub: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
|
||||||
|
)
|
||||||
|
hub_model_id: str = field(
|
||||||
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||||
|
)
|
||||||
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
|
||||||
|
the token values by removing their value.
|
||||||
|
"""
|
||||||
|
d = asdict(self)
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
d[k] = v.value
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
|
||||||
|
d[k] = [x.value for x in v]
|
||||||
|
if k.endswith("_token"):
|
||||||
|
d[k] = f"<{k.upper()}>"
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user