Context managers (#13900)
* add `ContextManagers` for lists of contexts * fix import sorting * add `ContextManagers` tests
This commit is contained in:
committed by
GitHub
parent
f875fb0e5f
commit
0270d44f57
@@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
@@ -20,7 +22,14 @@ import transformers
|
||||
|
||||
# Try to import everything from transformers to ensure every object can be loaded.
|
||||
from transformers import * # noqa F406
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
|
||||
from transformers.file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
filename_to_url,
|
||||
get_from_cache,
|
||||
hf_bucket_url,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
|
||||
|
||||
@@ -40,6 +49,21 @@ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d
|
||||
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
||||
|
||||
|
||||
# Dummy contexts to test `ContextManagers`
|
||||
@contextlib.contextmanager
|
||||
def context_en():
|
||||
print("Welcome!")
|
||||
yield
|
||||
print("Bye!")
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context_fr():
|
||||
print("Bonjour!")
|
||||
yield
|
||||
print("Au revoir!")
|
||||
|
||||
|
||||
def test_module_spec():
|
||||
assert transformers.__spec__ is not None
|
||||
assert importlib.util.find_spec("transformers") is not None
|
||||
@@ -85,3 +109,26 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_no_context(self, mock_stdout):
|
||||
with ContextManagers([]):
|
||||
print("Transformers are awesome!")
|
||||
# The print statement adds a new line at the end of the output
|
||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_one_context(self, mock_stdout):
|
||||
with ContextManagers([context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_two_context(self, mock_stdout):
|
||||
with ContextManagers([context_fr(), context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English and French welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||
|
||||
Reference in New Issue
Block a user