Remove all traces of low_cpu_mem_usage (#38792)
* remove it from all py files * remove it from the doc * remove it from examples * style * remove traces of _fast_init * Update test_peft_integration.py * CIs
This commit is contained in:
@@ -1109,14 +1109,16 @@ class T5ModelFp16Tests(unittest.TestCase):
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
"google-t5/t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
|
||||
"google-t5/t5-small",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
|
||||
|
||||
# Load without using `accelerate`
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
"google-t5/t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
"google-t5/t5-small",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
|
||||
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
|
||||
|
||||
Reference in New Issue
Block a user