Context managers (#13900)
* add `ContextManagers` for lists of contexts * fix import sorting * add `ContextManagers` tests
This commit is contained in:
committed by
GitHub
parent
f875fb0e5f
commit
0270d44f57
@@ -30,14 +30,14 @@ import tarfile
|
|||||||
import tempfile
|
import tempfile
|
||||||
import types
|
import types
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from contextlib import contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from zipfile import ZipFile, is_zipfile
|
from zipfile import ZipFile, is_zipfile
|
||||||
@@ -2365,3 +2365,21 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
|||||||
return f"{username}/{model_id}"
|
return f"{username}/{model_id}"
|
||||||
else:
|
else:
|
||||||
return f"{organization}/{model_id}"
|
return f"{organization}/{model_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManagers:
|
||||||
|
"""
|
||||||
|
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
|
||||||
|
in the `fastcore` library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context_managers: List[ContextManager]):
|
||||||
|
self.context_managers = context_managers
|
||||||
|
self.stack = ExitStack()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
for context_manager in self.context_managers:
|
||||||
|
self.stack.enter_context(context_manager)
|
||||||
|
|
||||||
|
def __exit__(self, *args, **kwargs):
|
||||||
|
self.stack.__exit__(*args, **kwargs)
|
||||||
|
|||||||
@@ -12,7 +12,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import importlib
|
import importlib
|
||||||
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -20,7 +22,14 @@ import transformers
|
|||||||
|
|
||||||
# Try to import everything from transformers to ensure every object can be loaded.
|
# Try to import everything from transformers to ensure every object can be loaded.
|
||||||
from transformers import * # noqa F406
|
from transformers import * # noqa F406
|
||||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
|
from transformers.file_utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
ContextManagers,
|
||||||
|
filename_to_url,
|
||||||
|
get_from_cache,
|
||||||
|
hf_bucket_url,
|
||||||
|
)
|
||||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||||
|
|
||||||
|
|
||||||
@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d
|
|||||||
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy contexts to test `ContextManagers`
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def context_en():
|
||||||
|
print("Welcome!")
|
||||||
|
yield
|
||||||
|
print("Bye!")
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def context_fr():
|
||||||
|
print("Bonjour!")
|
||||||
|
yield
|
||||||
|
print("Au revoir!")
|
||||||
|
|
||||||
|
|
||||||
def test_module_spec():
|
def test_module_spec():
|
||||||
assert transformers.__spec__ is not None
|
assert transformers.__spec__ is not None
|
||||||
assert importlib.util.find_spec("transformers") is not None
|
assert importlib.util.find_spec("transformers") is not None
|
||||||
@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase):
|
|||||||
filepath = get_from_cache(url, force_download=True)
|
filepath = get_from_cache(url, force_download=True)
|
||||||
metadata = filename_to_url(filepath)
|
metadata = filename_to_url(filepath)
|
||||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManagerTests(unittest.TestCase):
|
||||||
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
|
def test_no_context(self, mock_stdout):
|
||||||
|
with ContextManagers([]):
|
||||||
|
print("Transformers are awesome!")
|
||||||
|
# The print statement adds a new line at the end of the output
|
||||||
|
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||||
|
|
||||||
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
|
def test_one_context(self, mock_stdout):
|
||||||
|
with ContextManagers([context_en()]):
|
||||||
|
print("Transformers are awesome!")
|
||||||
|
# The output should be wrapped with an English welcome and goodbye
|
||||||
|
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||||
|
|
||||||
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
|
def test_two_context(self, mock_stdout):
|
||||||
|
with ContextManagers([context_fr(), context_en()]):
|
||||||
|
print("Transformers are awesome!")
|
||||||
|
# The output should be wrapped with an English and French welcome and goodbye
|
||||||
|
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||||
|
|||||||
Reference in New Issue
Block a user