Fix bug when requesting input normalization with EnCodec (#34756)

* EnCodec: unsqueeze padding mask

* add test for normalization
This commit is contained in:
Francesco Cariaggi
2025-01-07 11:50:02 +01:00
committed by GitHub
parent 96bf3d6cc5
commit f408d55448
2 changed files with 20 additions and 5 deletions

View File

@@ -39,7 +39,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import EncodecModel
from transformers import EncodecFeatureExtractor, EncodecModel
def prepare_inputs_dict(
@@ -111,6 +111,19 @@ class EncodecModelTester:
return config, inputs_dict
def prepare_config_and_inputs_for_normalization(self):
    """Build a config with normalization switched on, plus inputs from the feature extractor.

    Returns:
        A ``(config, inputs_dict)`` tuple. ``config`` comes from ``self.get_config()`` with
        ``normalize`` set to ``True``; ``inputs_dict`` is the ``EncodecFeatureExtractor``
        output (padded, returned as PyTorch tensors) moved to ``torch_device``.
    """
    raw_audio = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)

    config = self.get_config()
    config.normalize = True

    feature_extractor = EncodecFeatureExtractor(
        feature_size=config.audio_channels, sampling_rate=config.sampling_rate
    )
    # The extractor is fed a plain list holding one numpy array per batch item.
    audio_per_sample = list(raw_audio.cpu().numpy())
    features = feature_extractor(
        audio_per_sample, sampling_rate=config.sampling_rate, padding=True, return_tensors="pt"
    )
    inputs_dict = features.to(torch_device)

    return config, inputs_dict
def get_config(self):
return EncodecConfig(
audio_channels=self.num_channels,
@@ -125,9 +138,7 @@ class EncodecModelTester:
def create_and_check_model_forward(self, config, inputs_dict):
    """Run a single forward pass and check the shape of the returned audio.

    Args:
        config: Configuration used to instantiate ``EncodecModel``.
        inputs_dict: Mapping whose entries are all forwarded to the model as
            keyword arguments.
    """
    model = EncodecModel(config=config).to(torch_device).eval()
    # Forward every entry of inputs_dict rather than only "input_values", so any
    # additional inputs prepared by the caller actually reach the model. A single
    # call is enough: only this result is inspected below.
    result = model(**inputs_dict)
    self.parent.assertEqual(
        result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
    )
@@ -435,6 +446,10 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
def test_model_forward_with_normalization(self):
    """Run the forward-pass shape check on the config/inputs prepared for normalization."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs_for_normalization()
    config, inputs_dict = prepared
    tester.create_and_check_model_forward(config, inputs_dict)
def normalize(arr):
norm = np.linalg.norm(arr)