Enable pytest live log and show warning logs on GitHub Actions CI runs (#35912)
* fix * remove * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import unittest
|
||||
from queue import Empty
|
||||
from threading import Thread
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -27,6 +28,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
|
||||
from transformers.utils.logging import _get_library_root_logger
|
||||
|
||||
from ..test_modeling_common import ids_tensor
|
||||
|
||||
@@ -102,9 +104,12 @@ class StreamerTester(unittest.TestCase):
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
|
||||
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
|
||||
|
||||
root = _get_library_root_logger()
|
||||
with patch.object(root, "propagate", False):
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
|
||||
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
|
||||
|
||||
# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
|
||||
# re-tokenized, must only contain one token
|
||||
|
||||
Reference in New Issue
Block a user