From 0270d44f5741d9f983d24dd2f9e7ef952341093a Mon Sep 17 00:00:00 2001 From: Leandro von Werra Date: Wed, 20 Oct 2021 14:15:47 +0200 Subject: [PATCH] Context managers (#13900) * add `ContextManagers` for lists of contexts * fix import sorting * add `ContextManagers` tests --- src/transformers/file_utils.py | 22 +++++++++++++-- tests/test_file_utils.py | 49 +++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 1439b31292..fdac94ee3d 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -30,14 +30,14 @@ import tarfile import tempfile import types from collections import OrderedDict, UserDict -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from dataclasses import fields from enum import Enum from functools import partial, wraps from hashlib import sha256 from pathlib import Path from types import ModuleType -from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union +from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse from uuid import uuid4 from zipfile import ZipFile, is_zipfile @@ -2365,3 +2365,21 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{username}/{model_id}" else: return f"{organization}/{model_id}" + + +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. + """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + for context_manager in self.context_managers: + self.stack.enter_context(context_manager) + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index c32e3db1fa..4de449b344 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import importlib +import io import unittest import requests @@ -20,7 +22,14 @@ import transformers # Try to import everything from transformers to ensure every object can be loaded. from transformers import * # noqa F406 -from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url +from transformers.file_utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + ContextManagers, + filename_to_url, + get_from_cache, + hf_bucket_url, +) from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER @@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d # Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes +# Dummy contexts to test `ContextManagers` +@contextlib.contextmanager +def context_en(): + print("Welcome!") + yield + print("Bye!") + + +@contextlib.contextmanager +def context_fr(): + print("Bonjour!") + yield + print("Au revoir!") + + def test_module_spec(): assert transformers.__spec__ is not None assert importlib.util.find_spec("transformers") is not None @@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase): filepath = get_from_cache(url, force_download=True) metadata = filename_to_url(filepath) self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"')) + + +class ContextManagerTests(unittest.TestCase): + @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) + def test_no_context(self, mock_stdout): + with ContextManagers([]): + print("Transformers are awesome!") + # The print statement adds a new line at the end of the output + self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n") + + @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) + def test_one_context(self, mock_stdout): + with ContextManagers([context_en()]): + print("Transformers are awesome!") + # The output should be wrapped with an English welcome and goodbye + self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n") + + @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) + def test_two_context(self, mock_stdout): + with ContextManagers([context_fr(), context_en()]): + print("Transformers are awesome!") + # The output should be wrapped with an English and French welcome and goodbye + self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")