Fix return_dict in encodec (#31646)
* fix: use return_dict parameter * fix: type checks * fix: unused imports * update: one-line if else * remove: recursive check
This commit is contained in:
@@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict or self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
chunk_length = self.config.chunk_length
|
chunk_length = self.config.chunk_length
|
||||||
if chunk_length is None:
|
if chunk_length is None:
|
||||||
@@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
>>> audio_codes = outputs.audio_codes
|
>>> audio_codes = outputs.audio_codes
|
||||||
>>> audio_values = outputs.audio_values
|
>>> audio_values = outputs.audio_values
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict or self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
if padding_mask is None:
|
if padding_mask is None:
|
||||||
padding_mask = torch.ones_like(input_values).bool()
|
padding_mask = torch.ones_like(input_values).bool()
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import Audio, load_dataset
|
from datasets import Audio, load_dataset
|
||||||
@@ -385,32 +384,22 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
|
||||||
|
|
||||||
def recursive_check(tuple_object, dict_object):
|
self.assertTrue(isinstance(tuple_output, tuple))
|
||||||
if isinstance(tuple_object, (List, Tuple)):
|
self.assertTrue(isinstance(dict_output, dict))
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
|
||||||
elif isinstance(tuple_object, Dict):
|
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(
|
|
||||||
tuple_object.values(), dict_object.values()
|
|
||||||
):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
|
||||||
elif tuple_object is None:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(
|
torch.allclose(
|
||||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
|
||||||
),
|
),
|
||||||
msg=(
|
msg=(
|
||||||
"Tuple and dict output are not equal. Difference:"
|
"Tuple and dict output are not equal. Difference:"
|
||||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
|
||||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
|
||||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
recursive_check(tuple_output, dict_output)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user