Add Ray Tune hyperparameter search integration test (#10414)
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
|||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
require_datasets,
|
require_datasets,
|
||||||
require_optuna,
|
require_optuna,
|
||||||
|
require_ray,
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -80,6 +81,12 @@ class RegressionDataset:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class RegressionTrainingArguments(TrainingArguments):
    """``TrainingArguments`` subclass carrying two extra float fields, ``a`` and ``b``.

    Both fields default to ``0.0`` and are accepted as keyword arguments at
    construction time alongside the inherited training arguments.
    """

    # NOTE(review): presumably mirrors the regression model's `a`/`b` parameters -- confirm against the caller.
    a: float = 0.0
    b: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class RepeatDataset:
|
class RepeatDataset:
|
||||||
def __init__(self, x, length=64):
|
def __init__(self, x, length=64):
|
||||||
self.x = x
|
self.x = x
|
||||||
@@ -200,7 +207,8 @@ if is_torch_available():
|
|||||||
optimizers = kwargs.pop("optimizers", (None, None))
|
optimizers = kwargs.pop("optimizers", (None, None))
|
||||||
output_dir = kwargs.pop("output_dir", "./regression")
|
output_dir = kwargs.pop("output_dir", "./regression")
|
||||||
model_init = kwargs.pop("model_init", None)
|
model_init = kwargs.pop("model_init", None)
|
||||||
args = TrainingArguments(output_dir, **kwargs)
|
|
||||||
|
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
|
||||||
return Trainer(
|
return Trainer(
|
||||||
model,
|
model,
|
||||||
args,
|
args,
|
||||||
@@ -973,7 +981,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
class TrainerHyperParameterIntegrationTest(unittest.TestCase):
|
class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
args = TrainingArguments(".")
|
args = TrainingArguments(".")
|
||||||
self.n_epochs = args.num_train_epochs
|
self.n_epochs = args.num_train_epochs
|
||||||
@@ -1014,3 +1022,49 @@ class TrainerHyperParameterIntegrationTest(unittest.TestCase):
|
|||||||
model_init=model_init,
|
model_init=model_init,
|
||||||
)
|
)
|
||||||
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
|
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
@require_ray
class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
    """Integration test for ``Trainer.hyperparameter_search`` using ``backend="ray"``."""

    def setUp(self):
        # Read the stock epoch count and train batch size off a default TrainingArguments.
        args = TrainingArguments(".")
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.train_batch_size

    def test_hyperparameter_search(self):
        """Run a 4-trial Ray search over the integer hyperparameters ``a`` and ``b``."""

        class MyTrialShortNamer(TrialShortNamer):
            DEFAULTS = {"a": 0, "b": 0}

        def hp_space(trial):
            # Kept as a local import; presumably so this module still imports when ray is absent.
            from ray import tune

            return {
                "a": tune.randint(-4, 4),
                "b": tune.randint(-4, 4),
            }

        def model_init(config):
            # NOTE(review): Trainer may call `model_init` without a trial (config is None), e.g. to
            # build an initial model outside of a search -- confirm against the Trainer version in use.
            # Fall back to the declared defaults instead of failing on `None["a"]`.
            if config is None:
                config = MyTrialShortNamer.DEFAULTS
            model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)

            return RegressionPreTrainedModel(model_config)

        def hp_name(params):
            return MyTrialShortNamer.shortname(params)

        # Checkpoints go to a throwaway directory that is removed when the block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = get_regression_trainer(
                output_dir=tmp_dir,
                learning_rate=0.1,
                logging_steps=1,
                evaluation_strategy=EvaluationStrategy.EPOCH,
                num_train_epochs=4,
                disable_tqdm=True,
                load_best_model_at_end=True,
                logging_dir="runs",
                run_name="test",
                model_init=model_init,
            )
            trainer.hyperparameter_search(
                direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
            )
|
||||||
|
|||||||
Reference in New Issue
Block a user