diff --git a/src/transformers/models/bark/configuration_bark.py b/src/transformers/models/bark/configuration_bark.py index 635cb0aa1a..15efb11dc7 100644 --- a/src/transformers/models/bark/configuration_bark.py +++ b/src/transformers/models/bark/configuration_bark.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Union from ...configuration_utils import PretrainedConfig from ...utils import add_start_docstrings, logging -from ..auto import AutoConfig +from ..auto import CONFIG_MAPPING logger = logging.get_logger(__name__) @@ -299,7 +299,8 @@ class BarkConfig(PretrainedConfig): self.semantic_config = BarkSemanticConfig(**semantic_config) self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config) self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config) - self.codec_config = AutoConfig.for_model(**codec_config) + codec_model_type = codec_config["model_type"] if "model_type" in codec_config else "encodec" + self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config) self.initializer_range = initializer_range @@ -311,7 +312,7 @@ class BarkConfig(PretrainedConfig): semantic_config: BarkSemanticConfig, coarse_acoustics_config: BarkCoarseConfig, fine_acoustics_config: BarkFineConfig, - codec_config: AutoConfig, + codec_config: PretrainedConfig, **kwargs, ): r""" diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 4141e5c188..82a902ded4 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -22,6 +22,7 @@ import unittest from transformers import ( BarkCoarseConfig, + BarkConfig, BarkFineConfig, BarkSemanticConfig, is_torch_available, @@ -37,6 +38,7 @@ from transformers.utils import cached_property from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask +from ..encodec.test_modeling_encodec import 
EncodecModelTester if is_torch_available(): @@ -72,8 +74,6 @@ class BarkSemanticModelTester: initializer_range=0.02, n_codes_total=8, # for BarkFineModel n_codes_given=1, # for BarkFineModel - config_class=None, - model_class=None, ): self.parent = parent self.batch_size = batch_size @@ -98,8 +98,6 @@ class BarkSemanticModelTester: self.n_codes_given = n_codes_given self.is_encoder_decoder = False - self.config_class = config_class - self.model_class = model_class def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -121,7 +119,7 @@ class BarkSemanticModelTester: return config, inputs_dict def get_config(self): - return self.config_class( + return BarkSemanticConfig( vocab_size=self.vocab_size, output_vocab_size=self.output_vocab_size, hidden_size=self.hidden_size, @@ -137,6 +135,7 @@ class BarkSemanticModelTester: def get_pipeline_config(self): config = self.get_config() config.vocab_size = 300 + config.output_vocab_size = 300 return config def prepare_config_and_inputs_for_common(self): @@ -144,7 +143,7 @@ class BarkSemanticModelTester: return config, inputs_dict def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): - model = self.model_class(config=config).to(torch_device).eval() + model = BarkSemanticModel(config=config).to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] @@ -211,8 +210,6 @@ class BarkCoarseModelTester: initializer_range=0.02, n_codes_total=8, # for BarkFineModel n_codes_given=1, # for BarkFineModel - config_class=None, - model_class=None, ): self.parent = parent self.batch_size = batch_size @@ -237,8 +234,6 @@ class BarkCoarseModelTester: self.n_codes_given = n_codes_given self.is_encoder_decoder = False - self.config_class = config_class - self.model_class = model_class def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -260,7 +255,7 
@@ class BarkCoarseModelTester: return config, inputs_dict def get_config(self): - return self.config_class( + return BarkCoarseConfig( vocab_size=self.vocab_size, output_vocab_size=self.output_vocab_size, hidden_size=self.hidden_size, @@ -276,6 +271,7 @@ class BarkCoarseModelTester: def get_pipeline_config(self): config = self.get_config() config.vocab_size = 300 + config.output_vocab_size = 300 return config def prepare_config_and_inputs_for_common(self): @@ -283,7 +279,7 @@ class BarkCoarseModelTester: return config, inputs_dict def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): - model = self.model_class(config=config).to(torch_device).eval() + model = BarkCoarseModel(config=config).to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] @@ -350,8 +346,6 @@ class BarkFineModelTester: initializer_range=0.02, n_codes_total=8, # for BarkFineModel n_codes_given=1, # for BarkFineModel - config_class=None, - model_class=None, ): self.parent = parent self.batch_size = batch_size @@ -376,8 +370,6 @@ class BarkFineModelTester: self.n_codes_given = n_codes_given self.is_encoder_decoder = False - self.config_class = config_class - self.model_class = model_class def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length, self.n_codes_total], self.vocab_size) @@ -403,7 +395,7 @@ class BarkFineModelTester: return config, inputs_dict def get_config(self): - return self.config_class( + return BarkFineConfig( vocab_size=self.vocab_size, output_vocab_size=self.output_vocab_size, hidden_size=self.hidden_size, @@ -419,6 +411,7 @@ class BarkFineModelTester: def get_pipeline_config(self): config = self.get_config() config.vocab_size = 300 + config.output_vocab_size = 300 return config def prepare_config_and_inputs_for_common(self): @@ -426,7 +419,7 @@ class BarkFineModelTester: return config, inputs_dict def create_and_check_decoder_model_past_large_inputs(self, 
config, inputs_dict): - model = self.model_class(config=config).to(torch_device).eval() + model = BarkFineModel(config=config).to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] @@ -473,6 +466,79 @@ class BarkFineModelTester: self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) +class BarkModelTester: + def __init__( + self, + parent, + semantic_kwargs=None, + coarse_acoustics_kwargs=None, + fine_acoustics_kwargs=None, + codec_kwargs=None, + is_training=False, # for now training is not supported + ): + if semantic_kwargs is None: + semantic_kwargs = {} + if coarse_acoustics_kwargs is None: + coarse_acoustics_kwargs = {} + if fine_acoustics_kwargs is None: + fine_acoustics_kwargs = {} + if codec_kwargs is None: + codec_kwargs = {} + + self.parent = parent + self.semantic_model_tester = BarkSemanticModelTester(parent, **semantic_kwargs) + self.coarse_acoustics_model_tester = BarkCoarseModelTester(parent, **coarse_acoustics_kwargs) + self.fine_acoustics_model_tester = BarkFineModelTester(parent, **fine_acoustics_kwargs) + self.codec_model_tester = EncodecModelTester(parent, **codec_kwargs) + + self.is_training = is_training + + def prepare_config_and_inputs(self): + # TODO: @Yoach: Prepare `inputs_dict` + inputs_dict = {} + config = self.get_config() + + return config, inputs_dict + + def get_config(self): + return BarkConfig.from_sub_model_configs( + self.semantic_model_tester.get_config(), + self.coarse_acoustics_model_tester.get_config(), + self.fine_acoustics_model_tester.get_config(), + self.codec_model_tester.get_config(), + ) + + def get_pipeline_config(self): + config = self.get_config() + + # follow the `get_pipeline_config` of the sub component models + config.semantic_config.vocab_size = 300 + config.coarse_acoustics_config.vocab_size = 300 + config.fine_acoustics_config.vocab_size = 300 + + config.semantic_config.output_vocab_size = 300 + 
config.coarse_acoustics_config.output_vocab_size = 300 + config.fine_acoustics_config.output_vocab_size = 300 + + return config + + def prepare_config_and_inputs_for_common(self): + # TODO: @Yoach + pass + # return config, inputs_dict + + +# Need this class in order to create tiny model for `bark` +# TODO (@Yoach) Implement actual test methods +@unittest.skip("So far all tests will fail.") +class BarkModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (BarkModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = BarkModelTester(self) + self.config_tester = ConfigTester(self, config_class=BarkConfig, n_embd=37) + + @require_torch class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (BarkSemanticModel,) if is_torch_available() else () @@ -488,9 +554,7 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te test_resize_embeddings = True def setUp(self): - self.model_tester = BarkSemanticModelTester( - self, config_class=BarkSemanticConfig, model_class=BarkSemanticModel - ) + self.model_tester = BarkSemanticModelTester(self) self.config_tester = ConfigTester(self, config_class=BarkSemanticConfig, n_embd=37) def test_config(self): @@ -556,7 +620,7 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test test_resize_embeddings = True def setUp(self): - self.model_tester = BarkCoarseModelTester(self, config_class=BarkCoarseConfig, model_class=BarkCoarseModel) + self.model_tester = BarkCoarseModelTester(self) self.config_tester = ConfigTester(self, config_class=BarkCoarseConfig, n_embd=37) def test_config(self): @@ -623,7 +687,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): test_resize_embeddings = True def setUp(self): - self.model_tester = BarkFineModelTester(self, config_class=BarkFineConfig, model_class=BarkFineModel) + self.model_tester = BarkFineModelTester(self) 
self.config_tester = ConfigTester(self, config_class=BarkFineConfig, n_embd=37) def test_config(self): diff --git a/utils/create_dummy_models.py b/utils/create_dummy_models.py index f1b8736b59..87f3326504 100644 --- a/utils/create_dummy_models.py +++ b/utils/create_dummy_models.py @@ -974,6 +974,10 @@ def get_token_id_from_tokenizer(token_id_name, tokenizer, original_token_id): def get_config_overrides(config_class, processors): + # `Bark` configuration is too special. Let's just not handle this for now. + if config_class.__name__ == "BarkConfig": + return {} + config_overrides = {} # Check if there is any tokenizer (prefer fast version if any)