Ray Tune Integration Updates (#12134)
* fix * fixes * add back to scheduled tests * formatting * Update integrations.py
This commit is contained in:
4
.github/workflows/self-scheduled.yml
vendored
4
.github/workflows/self-scheduled.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
run: |
|
||||
apt -y update && apt install -y libsndfile1-dev
|
||||
pip install --upgrade pip
|
||||
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
|
||||
# Keep the extras list free of whitespace: the shell would otherwise split it into two arguments.
pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
|
||||
|
||||
- name: Are GPUs recognized by our DL frameworks
|
||||
run: |
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
run: |
|
||||
apt -y update && apt install -y libsndfile1-dev
|
||||
pip install --upgrade pip
|
||||
pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
|
||||
# Keep the extras list free of whitespace: the shell would otherwise split it into two arguments.
pip install .[integrations,sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
|
||||
|
||||
- name: Are GPUs recognized by our DL frameworks
|
||||
run: |
|
||||
|
||||
@@ -163,11 +163,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
local_trainer._tune_save_checkpoint()
|
||||
ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
|
||||
|
||||
if not trainer._memory_tracker.skip_memory_metrics:
|
||||
from .trainer_utils import TrainerMemoryTracker
|
||||
|
||||
logger.warning(
|
||||
"Memory tracking for your Trainer is currently "
|
||||
"enabled. Automatically disabling the memory tracker "
|
||||
"since the memory tracker is not serializable."
|
||||
)
|
||||
trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)
|
||||
|
||||
# The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
|
||||
# while doing the ray hp search.
|
||||
|
||||
_tb_writer = trainer.pop_callback(TensorBoardCallback)
|
||||
trainer.model = None
|
||||
|
||||
# Setup default `resources_per_trial`.
|
||||
if "resources_per_trial" not in kwargs:
|
||||
# Default to 1 CPU and 1 GPU (if applicable) per trial.
|
||||
@@ -194,7 +204,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
trainer.use_tune_checkpoints = True
|
||||
if kwargs["keep_checkpoints_num"] > 1:
|
||||
logger.warning(
|
||||
f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial. "
|
||||
f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
|
||||
"Checkpoints are usually huge, "
|
||||
"consider setting `keep_checkpoints_num=1`."
|
||||
)
|
||||
|
||||
@@ -1307,7 +1307,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
def ray_hyperparameter_search(self):
|
||||
class MyTrialShortNamer(TrialShortNamer):
|
||||
DEFAULTS = {"a": 0, "b": 0}
|
||||
|
||||
@@ -1320,7 +1320,13 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
|
||||
def model_init(config):
|
||||
model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
|
||||
if config is None:
|
||||
a = 0
|
||||
b = 0
|
||||
else:
|
||||
a = config["a"]
|
||||
b = config["b"]
|
||||
model_config = RegressionModelConfig(a=a, b=b, double_output=False)
|
||||
|
||||
return RegressionPreTrainedModel(model_config)
|
||||
|
||||
@@ -1343,3 +1349,14 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
|
||||
)
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
self.ray_hyperparameter_search()
|
||||
|
||||
def test_hyperparameter_search_ray_client(self):
|
||||
import ray
|
||||
from ray.util.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
with ray_start_client_server():
|
||||
assert ray.util.client.ray.is_connected()
|
||||
self.ray_hyperparameter_search()
|
||||
|
||||
Reference in New Issue
Block a user