Fix CUDA OOM by using a single Prior (#20486)
* fix CUDA OOM by using a single Prior
* only send to device when used
* use custom model
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers.trainer_utils import set_seed
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import JukeboxModel, JukeboxTokenizer
|
from transformers import JukeboxModel, JukeboxPrior, JukeboxTokenizer
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -312,7 +312,7 @@ class Jukebox5bModelTester(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_slow_sampling(self):
|
def test_slow_sampling(self):
|
||||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().to("cuda")
|
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
|
||||||
labels = [i.cuda() for i in self.prepare_inputs(self.model_id)]
|
labels = [i.cuda() for i in self.prepare_inputs(self.model_id)]
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
@@ -335,10 +335,11 @@ class Jukebox5bModelTester(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_fp16_slow_sampling(self):
|
def test_fp16_slow_sampling(self):
|
||||||
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().half().to("cuda")
|
prior_id = "ArthurZ/jukebox_prior_0"
|
||||||
labels = [i.cuda() for i in self.prepare_inputs(self.model_id)]
|
model = JukeboxPrior.from_pretrained(prior_id, min_duration=0).eval().half().to("cuda")
|
||||||
|
|
||||||
|
labels = self.prepare_inputs(prior_id)[0].cuda()
|
||||||
|
metadata = model.get_metadata(labels, 0, 7680, 0)
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
zs = [torch.zeros(1, 0, dtype=torch.long).cuda() for _ in range(3)]
|
outputs = model.sample(1, metadata=metadata, sample_tokens=60)
|
||||||
zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False)
|
torch.testing.assert_allclose(outputs[0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2))
|
||||||
torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user