Fast transformers import part 1 (#9441)
* Don't import libs to check they are available * Don't import integrations at init * Add importlib_metdata to deps * Remove old vars references * Avoid syntax error * Adapt testing utils * Try to appease torchhub * Add dependency * Remove more private variables * Fix typo * Another typo * Refine the tf availability test
This commit is contained in:
@@ -25,7 +25,7 @@ import shutil
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
# Integrations must be imported before ML frameworks:
|
||||
@@ -143,12 +143,6 @@ if is_mlflow_available():
|
||||
|
||||
DEFAULT_CALLBACKS.append(MLflowCallback)
|
||||
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
|
||||
if is_ray_tune_available():
|
||||
from ray import tune
|
||||
|
||||
if is_azureml_available():
|
||||
from .integrations import AzureMLCallback
|
||||
|
||||
@@ -159,6 +153,10 @@ if is_fairscale_available():
|
||||
from fairscale.optim import OSS
|
||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import optuna
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -611,15 +609,21 @@ class Trainer:
|
||||
return
|
||||
self.objective = self.compute_objective(metrics.copy())
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
import optuna
|
||||
|
||||
trial.report(self.objective, epoch)
|
||||
if trial.should_prune():
|
||||
raise optuna.TrialPruned()
|
||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||
from ray import tune
|
||||
|
||||
if self.state.global_step % self.args.save_steps == 0:
|
||||
self._tune_save_checkpoint()
|
||||
tune.report(objective=self.objective, **metrics)
|
||||
|
||||
def _tune_save_checkpoint(self):
|
||||
from ray import tune
|
||||
|
||||
if not self.use_tune_checkpoints:
|
||||
return
|
||||
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
|
||||
@@ -981,7 +985,12 @@ class Trainer:
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is not None and trial is not None:
|
||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
run_id = trial.number
|
||||
else:
|
||||
from ray import tune
|
||||
|
||||
run_id = tune.get_trial_id()
|
||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user