Fix low cpu mem usage tests (#30808)
* Fix tests
* fix udop failing test
* remove skip
* style
This commit is contained in:
@@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
|
|||||||
# get weights from encoder position bias
|
# get weights from encoder position bias
|
||||||
self.relative_bias = self._get_relative_bias(config)
|
self.relative_bias = self._get_relative_bias(config)
|
||||||
|
|
||||||
# tie weights of original position bias of encoder
|
def _tie_weights(self):
|
||||||
for bias in self.relative_bias.biases:
|
for bias in self.relative_bias.biases:
|
||||||
if isinstance(bias, RelativePositionBias1D):
|
if isinstance(bias, RelativePositionBias1D):
|
||||||
self._tie_or_clone_weights(
|
self._tie_or_clone_weights(
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import os.path
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
@@ -444,7 +443,6 @@ class ModelTesterMixin:
|
|||||||
@slow
|
@slow
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@unittest.skip("Need to fix since we have a device mismatch")
|
|
||||||
def test_save_load_low_cpu_mem_usage(self):
|
def test_save_load_low_cpu_mem_usage(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||||
@@ -457,7 +455,6 @@ class ModelTesterMixin:
|
|||||||
@slow
|
@slow
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@unittest.skip("Need to fix since we have a device mismatch")
|
|
||||||
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||||
@@ -471,7 +468,6 @@ class ModelTesterMixin:
|
|||||||
@slow
|
@slow
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@unittest.skip("Need to fix since we have a device mismatch")
|
|
||||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -482,6 +478,8 @@ class ModelTesterMixin:
|
|||||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||||
|
|
||||||
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
|
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
|
||||||
|
from accelerate.utils.modeling import named_module_tensors
|
||||||
|
|
||||||
# Load the low usage and the normal models.
|
# Load the low usage and the normal models.
|
||||||
model_low_usage, loading_info = model_class.from_pretrained(
|
model_low_usage, loading_info = model_class.from_pretrained(
|
||||||
saved_model_path,
|
saved_model_path,
|
||||||
@@ -496,16 +494,13 @@ class ModelTesterMixin:
|
|||||||
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
|
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
|
||||||
# subsequently loaded with the correct values and onto the correct device. We check if there are any
|
# subsequently loaded with the correct values and onto the correct device. We check if there are any
|
||||||
# remaining params that were not properly loaded.
|
# remaining params that were not properly loaded.
|
||||||
for name, param in model_low_usage.named_parameters():
|
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
param.device,
|
tensor.device,
|
||||||
torch.device("meta"),
|
torch.device("meta"),
|
||||||
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
|
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tests moving the model to a device other than meta.
|
|
||||||
model_low_usage.to(torch_device)
|
|
||||||
|
|
||||||
# Check that the parameters are equal.
|
# Check that the parameters are equal.
|
||||||
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
|
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
|
||||||
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
|
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user