Context managers (#13900)
* add `ContextManagers` for lists of contexts * fix import sorting * add `ContextManagers` tests
This commit is contained in:
committed by
GitHub
parent
f875fb0e5f
commit
0270d44f57
@@ -30,14 +30,14 @@ import tarfile
|
||||
import tempfile
|
||||
import types
|
||||
from collections import OrderedDict, UserDict
|
||||
from contextlib import contextmanager
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
@@ -2365,3 +2365,21 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
class ContextManagers:
|
||||
"""
|
||||
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
|
||||
in the `fastcore` library.
|
||||
"""
|
||||
|
||||
def __init__(self, context_managers: List[ContextManager]):
|
||||
self.context_managers = context_managers
|
||||
self.stack = ExitStack()
|
||||
|
||||
def __enter__(self):
|
||||
for context_manager in self.context_managers:
|
||||
self.stack.enter_context(context_manager)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.stack.__exit__(*args, **kwargs)
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
@@ -20,7 +22,14 @@ import transformers
|
||||
|
||||
# Try to import everything from transformers to ensure every object can be loaded.
|
||||
from transformers import * # noqa F406
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
|
||||
from transformers.file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
filename_to_url,
|
||||
get_from_cache,
|
||||
hf_bucket_url,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
|
||||
|
||||
@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d
|
||||
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
||||
|
||||
|
||||
# Dummy contexts to test `ContextManagers`
|
||||
@contextlib.contextmanager
|
||||
def context_en():
|
||||
print("Welcome!")
|
||||
yield
|
||||
print("Bye!")
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context_fr():
|
||||
print("Bonjour!")
|
||||
yield
|
||||
print("Au revoir!")
|
||||
|
||||
|
||||
def test_module_spec():
|
||||
assert transformers.__spec__ is not None
|
||||
assert importlib.util.find_spec("transformers") is not None
|
||||
@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_no_context(self, mock_stdout):
|
||||
with ContextManagers([]):
|
||||
print("Transformers are awesome!")
|
||||
# The print statement adds a new line at the end of the output
|
||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_one_context(self, mock_stdout):
|
||||
with ContextManagers([context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_two_context(self, mock_stdout):
|
||||
with ContextManagers([context_fr(), context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English and French welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||
|
||||
Reference in New Issue
Block a user