Fixed issue #21039 and added test for low_cpu_mem_usage
This commit is contained in:
Susnato Dhar
2023-01-12 14:33:13 +05:30
committed by GitHub
parent e849e5bb4a
commit b5be744d3c
2 changed files with 26 additions and 1 deletions

View File

@@ -3166,6 +3166,27 @@ class ModelUtilsTest(TestCasePlus):
):
_ = ModelWithHead.from_pretrained(tmp_dir)
@require_torch_gpu
def test_pretrained_low_mem_new_config(self):
# Checking for 1 model(the same one which was described in the issue) .
model_ids = ["gpt2"]
for model_id in model_ids:
model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
model_config.n_layer = 48
model_config.n_head = 25
model_config.n_embd = 1600
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_id,
config=model_config,
ignore_mismatched_sizes=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
@require_torch
@is_staging_test