Generate: Load generation config when device_map is passed (#25413)
This commit is contained in:
@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"'sequential'."
|
"'sequential'."
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = {"no_split_module_classes": no_split_modules}
|
device_map_kwargs = {"no_split_module_classes": no_split_modules}
|
||||||
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
||||||
kwargs["special_dtypes"] = special_dtypes
|
device_map_kwargs["special_dtypes"] = special_dtypes
|
||||||
elif len(special_dtypes) > 0:
|
elif len(special_dtypes) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
||||||
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
dtype=target_dtype,
|
dtype=target_dtype,
|
||||||
low_zero=(device_map == "balanced_low_0"),
|
low_zero=(device_map == "balanced_low_0"),
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
**kwargs,
|
**device_map_kwargs,
|
||||||
)
|
)
|
||||||
kwargs["max_memory"] = max_memory
|
device_map_kwargs["max_memory"] = max_memory
|
||||||
# Make sure tied weights are tied before creating the device map.
|
# Make sure tied weights are tied before creating the device map.
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
|
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
||||||
|
|
||||||
if load_in_8bit or load_in_4bit:
|
if load_in_8bit or load_in_4bit:
|
||||||
# The LM head / tied weights or any last module can stay on disk / CPU
|
# The LM head / tied weights or any last module can stay on disk / CPU
|
||||||
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# If it is a model with generation capabilities, attempt to load the generation config
|
# If it is a model with generation capabilities, attempt to load the generation config
|
||||||
if model.can_generate():
|
if model.can_generate() and pretrained_model_name_or_path is not None:
|
||||||
try:
|
try:
|
||||||
model.generation_config = GenerationConfig.from_pretrained(
|
model.generation_config = GenerationConfig.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
_from_pipeline=from_pipeline,
|
_from_pipeline=from_pipeline,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except (OSError, TypeError):
|
except OSError:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Generation config file not found, using a generation config created from the model config."
|
"Generation config file not found, using a generation config created from the model config."
|
||||||
)
|
)
|
||||||
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Dispatch model with hooks on all devices if necessary
|
# Dispatch model with hooks on all devices if necessary
|
||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
|
device_map_kwargs = {
|
||||||
|
"device_map": device_map,
|
||||||
|
"offload_dir": offload_folder,
|
||||||
|
"offload_index": offload_index,
|
||||||
|
}
|
||||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||||
kwargs["skip_keys"] = model._skip_keys_device_placement
|
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||||
dispatch_model(model, **kwargs)
|
dispatch_model(model, **device_map_kwargs)
|
||||||
|
|
||||||
if output_loading_info:
|
if output_loading_info:
|
||||||
if loading_info is None:
|
if loading_info is None:
|
||||||
|
|||||||
@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
||||||
|
|
||||||
|
def test_generation_config_is_loaded_with_model(self):
|
||||||
|
# Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
|
||||||
|
# `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
|
||||||
|
|
||||||
|
# 1. Load without further parameters
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
|
||||||
|
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||||
|
|
||||||
|
# 2. Load with `device_map`
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
|
||||||
|
)
|
||||||
|
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user