Clean up old Accelerate checks (#24279)
* Clean up old Accelerate checks
* Put back imports
This commit is contained in:
2
setup.py
2
setup.py
@@ -98,7 +98,7 @@ if stale_egg_info.exists():
|
|||||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||||
_deps = [
|
_deps = [
|
||||||
"Pillow",
|
"Pillow",
|
||||||
"accelerate>=0.20.2",
|
"accelerate>=0.20.3",
|
||||||
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
|
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
|
||||||
"beautifulsoup4",
|
"beautifulsoup4",
|
||||||
"black~=23.1",
|
"black~=23.1",
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# 2. run `make deps_table_update`
|
# 2. run `make deps_table_update`
|
||||||
deps = {
|
deps = {
|
||||||
"Pillow": "Pillow",
|
"Pillow": "Pillow",
|
||||||
"accelerate": "accelerate>=0.20.2",
|
"accelerate": "accelerate>=0.20.3",
|
||||||
"av": "av==9.2.0",
|
"av": "av==9.2.0",
|
||||||
"beautifulsoup4": "beautifulsoup4",
|
"beautifulsoup4": "beautifulsoup4",
|
||||||
"black": "black~=23.1",
|
"black": "black~=23.1",
|
||||||
|
|||||||
@@ -82,27 +82,17 @@ XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
|
|||||||
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import __version__ as accelerate_version
|
|
||||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
|
check_tied_parameters_on_same_device,
|
||||||
find_tied_parameters,
|
find_tied_parameters,
|
||||||
|
get_balanced_memory,
|
||||||
load_offloaded_weights,
|
load_offloaded_weights,
|
||||||
offload_weight,
|
offload_weight,
|
||||||
save_offload_index,
|
save_offload_index,
|
||||||
set_module_tensor_to_device,
|
set_module_tensor_to_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if version.parse(accelerate_version) > version.parse("0.11.0"):
|
|
||||||
from accelerate.utils import get_balanced_memory
|
|
||||||
else:
|
|
||||||
get_balanced_memory = None
|
|
||||||
if version.parse(accelerate_version) > version.parse("0.19.0"):
|
|
||||||
from accelerate.utils import check_tied_parameters_on_same_device
|
|
||||||
else:
|
|
||||||
check_tied_parameters_on_same_device = None
|
|
||||||
else:
|
|
||||||
find_tied_parameters = None
|
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import load_file as safe_load_file
|
from safetensors.torch import load_file as safe_load_file
|
||||||
@@ -2792,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
|
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
|
||||||
"'sequential'."
|
"'sequential'."
|
||||||
)
|
)
|
||||||
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
|
|
||||||
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
|
|
||||||
|
|
||||||
kwargs = {"no_split_module_classes": no_split_modules}
|
kwargs = {"no_split_module_classes": no_split_modules}
|
||||||
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
||||||
@@ -2803,7 +2791,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
||||||
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
|
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
|
||||||
)
|
)
|
||||||
if device_map != "sequential" and get_balanced_memory is not None:
|
if device_map != "sequential":
|
||||||
max_memory = get_balanced_memory(
|
max_memory = get_balanced_memory(
|
||||||
model,
|
model,
|
||||||
dtype=target_dtype,
|
dtype=target_dtype,
|
||||||
@@ -2838,8 +2826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
tied_params = find_tied_parameters(model)
|
tied_params = find_tied_parameters(model)
|
||||||
# check if we don't have tied param in different devices
|
# check if we don't have tied param in different devices
|
||||||
if check_tied_parameters_on_same_device is not None:
|
check_tied_parameters_on_same_device(tied_params, device_map)
|
||||||
check_tied_parameters_on_same_device(tied_params, device_map)
|
|
||||||
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
if resolved_archive_file.endswith(".index"):
|
if resolved_archive_file.endswith(".index"):
|
||||||
@@ -3031,7 +3018,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||||
|
|
||||||
if find_tied_parameters is not None:
|
if is_accelerate_available():
|
||||||
tied_params = find_tied_parameters(model)
|
tied_params = find_tied_parameters(model)
|
||||||
else:
|
else:
|
||||||
tied_params = []
|
tied_params = []
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ from collections.abc import Mapping
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
# Integrations must be imported before ML frameworks:
|
# Integrations must be imported before ML frameworks:
|
||||||
# isort: off
|
# isort: off
|
||||||
@@ -206,14 +204,9 @@ if is_peft_available():
|
|||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
|
|
||||||
skip_first_batches = None
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
|
from accelerate import Accelerator, skip_first_batches
|
||||||
from accelerate import __version__ as accelerate_version
|
from accelerate import __version__ as accelerate_version
|
||||||
|
|
||||||
if version.parse(accelerate_version) >= version.parse("0.16"):
|
|
||||||
from accelerate import skip_first_batches
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
if version.parse(accelerate_version) > version.parse("0.20.3"):
|
if version.parse(accelerate_version) > version.parse("0.20.3"):
|
||||||
@@ -322,6 +315,7 @@ class Trainer:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Those are used as methods of the Trainer in examples.
|
||||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
|
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1714,22 +1708,10 @@ class Trainer:
|
|||||||
logger.info(f" Continuing training from epoch {epochs_trained}")
|
logger.info(f" Continuing training from epoch {epochs_trained}")
|
||||||
logger.info(f" Continuing training from global step {self.state.global_step}")
|
logger.info(f" Continuing training from global step {self.state.global_step}")
|
||||||
if not args.ignore_data_skip:
|
if not args.ignore_data_skip:
|
||||||
if skip_first_batches is None:
|
logger.info(
|
||||||
logger.info(
|
f" Will skip the first {epochs_trained} epochs then the first"
|
||||||
f" Will skip the first {epochs_trained} epochs then the first"
|
f" {steps_trained_in_current_epoch} batches in the first epoch."
|
||||||
f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
|
)
|
||||||
" you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
|
|
||||||
" also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
|
|
||||||
" training on data already seen by your model."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f" Will skip the first {epochs_trained} epochs then the first"
|
|
||||||
f" {steps_trained_in_current_epoch} batches in the first epoch."
|
|
||||||
)
|
|
||||||
if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
|
|
||||||
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
|
|
||||||
steps_trained_progress_bar.set_description("Skipping the first batches")
|
|
||||||
|
|
||||||
# Update the references
|
# Update the references
|
||||||
self.callback_handler.model = self.model
|
self.callback_handler.model = self.model
|
||||||
@@ -1787,7 +1769,7 @@ class Trainer:
|
|||||||
|
|
||||||
rng_to_sync = False
|
rng_to_sync = False
|
||||||
steps_skipped = 0
|
steps_skipped = 0
|
||||||
if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
|
if steps_trained_in_current_epoch > 0:
|
||||||
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
|
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
|
||||||
steps_skipped = steps_trained_in_current_epoch
|
steps_skipped = steps_trained_in_current_epoch
|
||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user