Move tests/utils.py -> transformers/testing_utils.py (#5350)
This commit is contained in:
@@ -12,7 +12,7 @@ export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py and utils.py
|
||||
# Add parent directory to python path to access lightning_base.py and testing_utils.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
python finetune.py \
|
||||
--data_dir=cnn_tiny/ \
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.testing_utils import require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
@@ -107,7 +108,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||
@require_multigpu
|
||||
def test_multigpu(self):
|
||||
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
Reference in New Issue
Block a user