Decorators for deprecation and named arguments validation (#30799)
* Fix do_reduce_labels for maskformer image processor * Deprecate reduce_labels in favor to do_reduce_labels * Deprecate reduce_labels in favor to do_reduce_labels (segformer) * Deprecate reduce_labels in favor to do_reduce_labels (oneformer) * Deprecate reduce_labels in favor to do_reduce_labels (maskformer) * Deprecate reduce_labels in favor to do_reduce_labels (mask2former) * Fix typo * Update mask2former test * fixup * Update segmentation examples * Update docs * Fixup * Imports fixup * Add deprecation decorator draft * Add deprecation decorator * Fixup * Add deprecate_kwarg decorator * Validate kwargs decorator * Kwargs validation (beit) * fixup * Kwargs validation (mask2former) * Kwargs validation (maskformer) * Kwargs validation (oneformer) * Kwargs validation (segformer) * Better message * Fix oneformer processor save-load test * Update src/transformers/utils/deprecation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Update src/transformers/utils/deprecation.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Better handle classmethod warning * Fix typo, remove warn * Add header * Docs and `additional_message` * Move to filter decorator ot generic * Proper deprecation for semantic segm scripts * Add to __init__ and update import * Basic tests for filter decorator * Fix doc * Override `to_dict()` to pop depracated `_max_size` * Pop unused parameters * Fix trailing whitespace * Add test for deprecation * Add deprecation warning control parameter * Update generic test * Fixup deprecation tests * Introduce init service kwargs * Revert popping unused params * Revert oneformer test * Allow "metadata" to pass * Better docs * Fix test * Add notion in docstring * Fix notification for both names * Add func name to warning message * Fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4fa4dcb2be
commit
517df566f5
@@ -14,12 +14,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_flax, require_tf, require_torch
|
||||
from transformers.utils import (
|
||||
expand_dims,
|
||||
filter_out_non_signature_kwargs,
|
||||
flatten_dict,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
@@ -198,3 +200,74 @@ class GenericTester(unittest.TestCase):
|
||||
x = np.random.randn(3, 4)
|
||||
t = jnp.array(x)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
||||
|
||||
|
||||
class ValidationDecoratorTester(unittest.TestCase):
|
||||
def test_cases_no_warning(self):
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
# basic test
|
||||
@filter_out_non_signature_kwargs()
|
||||
def func1(a):
|
||||
return a
|
||||
|
||||
result = func1(1)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
# include extra kwarg
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||
def func2(a, **kwargs):
|
||||
return a, kwargs
|
||||
|
||||
a, kwargs = func2(1)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
a, kwargs = func2(1, extra_arg=2)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||
|
||||
# multiple extra kwargs
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||
def func3(a, **kwargs):
|
||||
return a, kwargs
|
||||
|
||||
a, kwargs = func3(2)
|
||||
self.assertEqual(a, 2)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
a, kwargs = func3(3, extra_arg2=3)
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(kwargs, {"extra_arg2": 3})
|
||||
|
||||
a, kwargs = func3(1, extra_arg=2, extra_arg2=3)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||
|
||||
# Check that no warnings were raised
|
||||
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
|
||||
|
||||
def test_cases_with_warnings(self):
|
||||
@filter_out_non_signature_kwargs()
|
||||
def func1(a):
|
||||
return a
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
func1(1, extra_arg=2)
|
||||
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||
def func2(a, **kwargs):
|
||||
return kwargs
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
kwargs = func2(1, extra_arg=2, extra_arg2=3)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||
def func3(a, **kwargs):
|
||||
return kwargs
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||
|
||||
Reference in New Issue
Block a user