Doc styler examples (#14953)
* Fix bad examples * Add black formatting to style_doc * Use first nonempty line * Put it at the right place * Don't add spaces to empty lines * Better templates * Deal with triple quotes in docstrings * Result of style_doc * Enable mdx treatment and fix code examples in MDXs * Result of doc styler on doc source files * Last fixes * Break copy from
This commit is contained in:
@@ -422,14 +422,14 @@ Let's depict the GPU requirements in the following table:
|
||||
|
||||
For example, here is a test that must be run only when there are 2 or more GPUs available and pytorch is installed:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@require_torch_multi_gpu
|
||||
def test_example_with_multi_gpu():
|
||||
```
|
||||
|
||||
If a test requires `tensorflow` use the `require_tf` decorator. For example:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@require_tf
|
||||
def test_tf_thing_with_tensorflow():
|
||||
```
|
||||
@@ -437,7 +437,7 @@ def test_tf_thing_with_tensorflow():
|
||||
These decorators can be stacked. For example, if a test is slow and requires at least one GPU under pytorch, here is
|
||||
how to set it up:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_example_slow_on_gpu():
|
||||
@@ -446,7 +446,7 @@ def test_example_slow_on_gpu():
|
||||
Some decorators like `@parametrized` rewrite test names, therefore `@require_*` skip decorators have to be listed
|
||||
last for them to work correctly. Here is an example of the correct usage:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@parameterized.expand(...)
|
||||
@require_torch_multi_gpu
|
||||
def test_integration_foo():
|
||||
@@ -461,7 +461,8 @@ Inside tests:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import get_gpu_count
|
||||
n_gpu = get_gpu_count() # works with torch and tf
|
||||
|
||||
n_gpu = get_gpu_count() # works with torch and tf
|
||||
```
|
||||
|
||||
### Distributed training
|
||||
@@ -544,12 +545,16 @@ the test, but then there is no way of running that test for just one set of argu
|
||||
# test_this1.py
|
||||
import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
|
||||
class TestMathUnitTest(unittest.TestCase):
|
||||
@parameterized.expand([
|
||||
("negative", -1.5, -2.0),
|
||||
("integer", 1, 1.0),
|
||||
("large fraction", 1.6, 1),
|
||||
])
|
||||
@parameterized.expand(
|
||||
[
|
||||
("negative", -1.5, -2.0),
|
||||
("integer", 1, 1.0),
|
||||
("large fraction", 1.6, 1),
|
||||
]
|
||||
)
|
||||
def test_floor(self, name, input, expected):
|
||||
assert_equal(math.floor(input), expected)
|
||||
```
|
||||
@@ -601,6 +606,8 @@ Here is the same example, this time using `pytest`'s `parametrize` marker:
|
||||
```python
|
||||
# test_this2.py
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name, input, expected",
|
||||
[
|
||||
@@ -669,6 +676,8 @@ To start using those all you need is to make sure that the test resides in a sub
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
class PathExampleTest(TestCasePlus):
|
||||
def test_something_involving_local_locations(self):
|
||||
data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
|
||||
@@ -679,6 +688,8 @@ If you don't need to manipulate paths via `pathlib` or you just need a path as a
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
class PathExampleTest(TestCasePlus):
|
||||
def test_something_involving_stringified_locations(self):
|
||||
examples_dir = self.examples_dir_str
|
||||
@@ -700,6 +711,8 @@ Here is an example of its usage:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
class ExamplesTests(TestCasePlus):
|
||||
def test_whatever(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
@@ -759,6 +772,7 @@ If you need to temporary override `sys.path` to import from another test for exa
|
||||
```python
|
||||
import os
|
||||
from transformers.testing_utils import ExtendSysPath
|
||||
|
||||
bindir = os.path.abspath(os.path.dirname(__file__))
|
||||
with ExtendSysPath(f"{bindir}/.."):
|
||||
from test_trainer import TrainerIntegrationCommon # noqa
|
||||
@@ -786,20 +800,20 @@ code that's buggy causes some bad state that will affect other tests, do not use
|
||||
|
||||
- Here is how to skip whole test unconditionally:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@unittest.skip("this bug needs to be fixed")
|
||||
def test_feature_x():
|
||||
```
|
||||
|
||||
or via pytest:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@pytest.mark.skip(reason="this bug needs to be fixed")
|
||||
```
|
||||
|
||||
or the `xfail` way:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@pytest.mark.xfail
|
||||
def test_feature_x():
|
||||
```
|
||||
@@ -816,6 +830,7 @@ or the whole module:
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
if not pytest.config.getoption("--custom-flag"):
|
||||
pytest.skip("--custom-flag is missing, skipping tests", allow_module_level=True)
|
||||
```
|
||||
@@ -835,21 +850,21 @@ docutils = pytest.importorskip("docutils", minversion="0.3")
|
||||
|
||||
- Skip a test based on a condition:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@pytest.mark.skipif(sys.version_info < (3,6), reason="requires python3.6 or higher")
|
||||
def test_feature_x():
|
||||
```
|
||||
|
||||
or:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@unittest.skipIf(torch_device == "cpu", "Can't do half precision")
|
||||
def test_feature_x():
|
||||
```
|
||||
|
||||
or skip the whole module:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
@pytest.mark.skipif(sys.platform == 'win32', reason="does not run on windows")
|
||||
class TestClass():
|
||||
def test_feature_x(self):
|
||||
@@ -863,7 +878,7 @@ The library of tests is ever-growing, and some of the tests take minutes to run,
|
||||
an hour for the test suite to complete on CI. Therefore, with some exceptions for essential tests, slow tests should be
|
||||
marked as in the example below:
|
||||
|
||||
```python
|
||||
```python no-style
|
||||
from transformers.testing_utils import slow
|
||||
@slow
|
||||
def test_integration_foo():
|
||||
@@ -878,8 +893,8 @@ RUN_SLOW=1 pytest tests
|
||||
Some decorators like `@parameterized` rewrite test names, therefore `@slow` and the rest of the skip decorators
|
||||
`@require_*` have to be listed last for them to work correctly. Here is an example of the correct usage:
|
||||
|
||||
```python
|
||||
@parameterized.expand(...)
|
||||
```python no-style
|
||||
@parameteriz ed.expand(...)
|
||||
@slow
|
||||
def test_integration_foo():
|
||||
```
|
||||
@@ -935,13 +950,21 @@ In order to test functions that write to `stdout` and/or `stderr`, the test can
|
||||
|
||||
```python
|
||||
import sys
|
||||
def print_to_stdout(s): print(s)
|
||||
def print_to_stderr(s): sys.stderr.write(s)
|
||||
|
||||
|
||||
def print_to_stdout(s):
|
||||
print(s)
|
||||
|
||||
|
||||
def print_to_stderr(s):
|
||||
sys.stderr.write(s)
|
||||
|
||||
|
||||
def test_result_and_stdout(capsys):
|
||||
msg = "Hello"
|
||||
print_to_stdout(msg)
|
||||
print_to_stderr(msg)
|
||||
out, err = capsys.readouterr() # consume the captured output streams
|
||||
out, err = capsys.readouterr() # consume the captured output streams
|
||||
# optional: if you want to replay the consumed streams:
|
||||
sys.stdout.write(out)
|
||||
sys.stderr.write(err)
|
||||
@@ -954,10 +977,13 @@ And, of course, most of the time, `stderr` will come as a part of an exception,
|
||||
a case:
|
||||
|
||||
```python
|
||||
def raise_exception(msg): raise ValueError(msg)
|
||||
def raise_exception(msg):
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def test_something_exception():
|
||||
msg = "Not a good value"
|
||||
error = ''
|
||||
error = ""
|
||||
try:
|
||||
raise_exception(msg)
|
||||
except Exception as e:
|
||||
@@ -970,7 +996,12 @@ Another approach to capturing stdout is via `contextlib.redirect_stdout`:
|
||||
```python
|
||||
from io import StringIO
|
||||
from contextlib import redirect_stdout
|
||||
def print_to_stdout(s): print(s)
|
||||
|
||||
|
||||
def print_to_stdout(s):
|
||||
print(s)
|
||||
|
||||
|
||||
def test_result_and_stdout():
|
||||
msg = "Hello"
|
||||
buffer = StringIO()
|
||||
@@ -993,6 +1024,7 @@ some `\r`'s in it or not, so it's a simple:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
|
||||
with CaptureStdout() as cs:
|
||||
function_that_writes_to_stdout()
|
||||
print(cs.out)
|
||||
@@ -1002,17 +1034,19 @@ Here is a full test example:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
|
||||
msg = "Secret message\r"
|
||||
final = "Hello World"
|
||||
with CaptureStdout() as cs:
|
||||
print(msg + final)
|
||||
assert cs.out == final+"\n", f"captured: {cs.out}, expecting {final}"
|
||||
assert cs.out == final + "\n", f"captured: {cs.out}, expecting {final}"
|
||||
```
|
||||
|
||||
If you'd like to capture `stderr` use the `CaptureStderr` class instead:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import CaptureStderr
|
||||
|
||||
with CaptureStderr() as cs:
|
||||
function_that_writes_to_stderr()
|
||||
print(cs.err)
|
||||
@@ -1022,6 +1056,7 @@ If you need to capture both streams at once, use the parent `CaptureStd` class:
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import CaptureStd
|
||||
|
||||
with CaptureStd() as cs:
|
||||
function_that_writes_to_stdout_and_stderr()
|
||||
print(cs.err, cs.out)
|
||||
@@ -1044,7 +1079,7 @@ logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.models.bart.tokenization_bart")
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.info(msg)
|
||||
assert cl.out, msg+"\n"
|
||||
assert cl.out, msg + "\n"
|
||||
```
|
||||
|
||||
### Testing with environment variables
|
||||
@@ -1054,6 +1089,8 @@ If you want to test the impact of environment variables for a specific test you
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import mockenv
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
@mockenv(TRANSFORMERS_VERBOSITY="error")
|
||||
def test_env_override(self):
|
||||
@@ -1065,6 +1102,8 @@ multiple local paths. A helper class `transformers.test_utils.TestCasePlus` come
|
||||
|
||||
```python
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
class EnvExampleTest(TestCasePlus):
|
||||
def test_external_prog(self):
|
||||
env = self.get_env()
|
||||
@@ -1089,16 +1128,20 @@ seed = 42
|
||||
|
||||
# python RNG
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
# pytorch RNGs
|
||||
import torch
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# numpy RNG
|
||||
import numpy as np
|
||||
|
||||
np.random.seed(seed)
|
||||
|
||||
# tf RNG
|
||||
|
||||
Reference in New Issue
Block a user