Fast transformers import part 1 (#9441)

* Don't import libs to check they are available

* Don't import integrations at init

* Add importlib_metdata to deps

* Remove old vars references

* Avoid syntax error

* Adapt testing utils

* Try to appease torchhub

* Add dependency

* Remove more private variables

* Fix typo

* Another typo

* Refine the tf availability test
This commit is contained in:
Sylvain Gugger
2021-01-06 12:17:24 -05:00
committed by GitHub
parent c89f1bc92e
commit 0c96262f7d
13 changed files with 280 additions and 360 deletions

View File

@@ -25,7 +25,7 @@ import shutil
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
# Integrations must be imported before ML frameworks:
@@ -143,12 +143,6 @@ if is_mlflow_available():
DEFAULT_CALLBACKS.append(MLflowCallback)
if is_optuna_available():
import optuna
if is_ray_tune_available():
from ray import tune
if is_azureml_available():
from .integrations import AzureMLCallback
@@ -159,6 +153,10 @@ if is_fairscale_available():
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if TYPE_CHECKING:
import optuna
logger = logging.get_logger(__name__)
@@ -611,15 +609,21 @@ class Trainer:
return
self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
trial.report(self.objective, epoch)
if trial.should_prune():
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
if self.state.global_step % self.args.save_steps == 0:
self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self):
from ray import tune
if not self.use_tune_checkpoints:
return
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
@@ -981,7 +985,12 @@ class Trainer:
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
else:
from ray import tune
run_id = tune.get_trial_id()
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
else: