Generate: Load generation config when device_map is passed (#25413)

This commit is contained in:
Joao Gante
2023-08-10 10:54:26 +01:00
committed by GitHub
parent d0839f1a74
commit 3e41cf13fc
2 changed files with 28 additions and 10 deletions

View File

@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
def test_generation_config_is_loaded_with_model(self):
# Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
# `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
# 1. Load without further parameters
model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
self.assertEqual(model.generation_config.transformers_version, "foo")
# 2. Load with `device_map`
model = AutoModelForCausalLM.from_pretrained(
"joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
)
self.assertEqual(model.generation_config.transformers_version, "foo")
@require_torch
@is_staging_test