@@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
|
|||||||
self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
|
self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_list_devices(model):
|
# Check correct device map
|
||||||
list_devices = []
|
self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})
|
||||||
for _, module in model.named_children():
|
|
||||||
if len(list(module.children())) > 0:
|
|
||||||
list_devices.extend(get_list_devices(module))
|
|
||||||
else:
|
|
||||||
# Do a try except since we can encounter Dropout modules that does not
|
|
||||||
# have any device set
|
|
||||||
try:
|
|
||||||
list_devices.append(next(module.parameters()).device.index)
|
|
||||||
except BaseException:
|
|
||||||
continue
|
|
||||||
return list_devices
|
|
||||||
|
|
||||||
list_devices = get_list_devices(model_parallel)
|
|
||||||
# Check that we have dispatched the model into 2 separate devices
|
|
||||||
self.assertTrue((1 in list_devices) and (0 in list_devices))
|
|
||||||
|
|
||||||
# Check that inference pass works on the model
|
# Check that inference pass works on the model
|
||||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||||
|
|||||||
Reference in New Issue
Block a user