Fix return_dict in encodec (#31646)

* fix: use return_dict parameter

* fix: type checks

* fix: unused imports

* update: one-line if else

* remove: recursive check
This commit is contained in:
Jacky Lee
2024-06-28 04:18:01 -07:00
committed by GitHub
parent 5e89b335ab
commit 82a1fc7256
2 changed files with 16 additions and 27 deletions

View File

@@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
""" """
return_dict = return_dict or self.config.return_dict return_dict = return_dict if return_dict is not None else self.config.return_dict
chunk_length = self.config.chunk_length chunk_length = self.config.chunk_length
if chunk_length is None: if chunk_length is None:
@@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
>>> audio_codes = outputs.audio_codes >>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values >>> audio_values = outputs.audio_values
```""" ```"""
return_dict = return_dict or self.config.return_dict return_dict = return_dict if return_dict is not None else self.config.return_dict
if padding_mask is None: if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool() padding_mask = torch.ones_like(input_values).bool()

View File

@@ -19,7 +19,6 @@ import inspect
import os import os
import tempfile import tempfile
import unittest import unittest
from typing import Dict, List, Tuple
import numpy as np import numpy as np
from datasets import Audio, load_dataset from datasets import Audio, load_dataset
@@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
def recursive_check(tuple_object, dict_object): self.assertTrue(isinstance(tuple_output, tuple))
if isinstance(tuple_object, (List, Tuple)): self.assertTrue(isinstance(dict_output, dict))
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output) for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
),
)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)