Update Bark generation configs and tests (#25409)
* update bark generation configs for more coherent parameter * make style * update bark hub repo
This commit is contained in:
@@ -894,11 +894,11 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BarkModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def model(self):
|
||||
return BarkModel.from_pretrained("ylacombe/bark-large").to(torch_device)
|
||||
return BarkModel.from_pretrained("suno/bark").to(torch_device)
|
||||
|
||||
@cached_property
|
||||
def processor(self):
|
||||
return BarkProcessor.from_pretrained("ylacombe/bark-large")
|
||||
return BarkProcessor.from_pretrained("suno/bark")
|
||||
|
||||
@cached_property
|
||||
def inputs(self):
|
||||
@@ -937,6 +937,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
output_ids = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
|
||||
@@ -957,6 +958,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
output_ids = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
|
||||
@@ -964,6 +966,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
output_ids,
|
||||
history_prompt=history_prompt,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
coarse_generation_config=self.coarse_generation_config,
|
||||
codebook_size=self.model.generation_config.codebook_size,
|
||||
@@ -994,6 +997,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
output_ids = self.model.semantic.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
)
|
||||
|
||||
@@ -1001,6 +1005,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
output_ids,
|
||||
history_prompt=history_prompt,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
semantic_generation_config=self.semantic_generation_config,
|
||||
coarse_generation_config=self.coarse_generation_config,
|
||||
codebook_size=self.model.generation_config.codebook_size,
|
||||
@@ -1040,9 +1045,16 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
input_ids = self.inputs
|
||||
|
||||
with torch.no_grad():
|
||||
self.model.generate(**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7)
|
||||
self.model.generate(
|
||||
**input_ids, do_sample=False, coarse_do_sample=True, coarse_temperature=0.7, fine_temperature=0.3
|
||||
**input_ids, do_sample=False, temperature=1.0, coarse_do_sample=True, coarse_temperature=0.7
|
||||
)
|
||||
self.model.generate(
|
||||
**input_ids,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
coarse_do_sample=True,
|
||||
coarse_temperature=0.7,
|
||||
fine_temperature=0.3,
|
||||
)
|
||||
self.model.generate(
|
||||
**input_ids,
|
||||
@@ -1061,7 +1073,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
with torch.no_grad():
|
||||
# standard generation
|
||||
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
|
||||
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1088,7 +1100,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))
|
||||
|
||||
# output with cpu offload
|
||||
output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
|
||||
output_with_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
|
||||
|
||||
# checks if same output
|
||||
self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())
|
||||
|
||||
Reference in New Issue
Block a user