From 82a1fc7256bf27f83aec3a93543b6d156add09cf Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Fri, 28 Jun 2024 04:18:01 -0700 Subject: [PATCH] Fix return_dict in encodec (#31646) * fix: use return_dict parameter * fix: type checks * fix: unused imports * update: one-line if else * remove: recursive check --- .../models/encodec/modeling_encodec.py | 4 +- tests/models/encodec/test_modeling_encodec.py | 39 +++++++------------ 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 9627742b9e..f325a6adbe 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - return_dict = return_dict or self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.return_dict chunk_length = self.config.chunk_length if chunk_length is None: @@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel): >>> audio_codes = outputs.audio_codes >>> audio_values = outputs.audio_values ```""" - return_dict = return_dict or self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.return_dict if padding_mask is None: padding_mask = torch.ones_like(input_values).bool() diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index e4f66d8564..0a023894d8 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -19,7 +19,6 @@ import inspect import os import tempfile import unittest -from typing import Dict, List, Tuple import numpy as np from datasets import Audio, load_dataset @@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip( - tuple_object.values(), dict_object.values() - ): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - torch.allclose( - set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 - ), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" - f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." - ), - ) + self.assertTrue(isinstance(tuple_output, tuple)) + self.assertTrue(isinstance(dict_output, dict)) - recursive_check(tuple_output, dict_output) + for tuple_value, dict_value in zip(tuple_output, dict_output.values()): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:" + f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has" + f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}." + ), + ) for model_class in self.all_model_classes: model = model_class(config)