[s2s] Test hub configs in self-scheduled CI (#6809)
This commit is contained in:
@@ -13,9 +13,10 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import lightning_base
|
import lightning_base
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from transformers.hf_api import HfApi
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
@@ -116,6 +117,25 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_and_cuda
|
||||||
|
def test_hub_configs(self):
|
||||||
|
"""I put require_torch_and_cuda cause I only want this to run with self-scheduled."""
|
||||||
|
|
||||||
|
model_list = HfApi().model_list()
|
||||||
|
org = "sshleifer"
|
||||||
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||||
|
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
||||||
|
failures = []
|
||||||
|
for m in model_ids:
|
||||||
|
if m in allowed_to_be_broken:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
AutoConfig.from_pretrained(m)
|
||||||
|
except Exception:
|
||||||
|
failures.append(m)
|
||||||
|
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
||||||
|
|
||||||
@require_multigpu
|
@require_multigpu
|
||||||
def test_multigpu(self):
|
def test_multigpu(self):
|
||||||
updates = dict(
|
updates = dict(
|
||||||
|
|||||||
Reference in New Issue
Block a user