[s2s] Test hub configs in self-scheduled CI (#6809)

This commit is contained in:
Sam Shleifer
2020-08-28 17:05:52 -04:00
committed by GitHub
parent 3cac867fac
commit 5ab21b072f

View File

@@ -13,9 +13,10 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import lightning_base import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.hf_api import HfApi
from transformers.modeling_bart import shift_tokens_right from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
from .distillation import distill_main, evaluate_checkpoint from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main from .finetune import SummarizationModule, main
@@ -116,6 +117,25 @@ class TestSummarizationDistiller(unittest.TestCase):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls return cls
@slow
@require_torch_and_cuda
def test_hub_configs(self):
"""I put require_torch_and_cuda cause I only want this to run with self-scheduled."""
model_list = HfApi().model_list()
org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
failures = []
for m in model_ids:
if m in allowed_to_be_broken:
continue
try:
AutoConfig.from_pretrained(m)
except Exception:
failures.append(m)
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
@require_multigpu @require_multigpu
def test_multigpu(self): def test_multigpu(self):
updates = dict( updates = dict(