Context managers (#13900)

* add `ContextManagers` for lists of contexts

* fix import sorting

* add `ContextManagers` tests
This commit is contained in:
Leandro von Werra
2021-10-20 14:15:47 +02:00
committed by GitHub
parent f875fb0e5f
commit 0270d44f57
2 changed files with 68 additions and 3 deletions

View File

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib
import io
import unittest
import requests
@@ -20,7 +22,14 @@ import transformers
# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
from transformers.file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
ContextManagers,
filename_to_url,
get_from_cache,
hf_bucket_url,
)
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
# Dummy contexts to test `ContextManagers`
@contextlib.contextmanager
def context_en():
print("Welcome!")
yield
print("Bye!")
@contextlib.contextmanager
def context_fr():
print("Bonjour!")
yield
print("Au revoir!")
def test_module_spec():
assert transformers.__spec__ is not None
assert importlib.util.find_spec("transformers") is not None
@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase):
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
class ContextManagerTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout):
with ContextManagers([]):
print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout):
with ContextManagers([context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")