Now supporting pathlike in pipelines too. (#20030)
This commit is contained in:
@@ -21,6 +21,7 @@ import os
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from numpy import isin
|
from numpy import isin
|
||||||
@@ -638,6 +639,8 @@ def pipeline(
|
|||||||
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
|
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
|
||||||
" or a path/identifier to a pretrained model when providing feature_extractor."
|
" or a path/identifier to a pretrained model when providing feature_extractor."
|
||||||
)
|
)
|
||||||
|
if isinstance(model, Path):
|
||||||
|
model = str(model)
|
||||||
|
|
||||||
# Config is the primordial information item.
|
# Config is the primordial information item.
|
||||||
# Instantiate config if needed
|
# Instantiate config if needed
|
||||||
|
|||||||
@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
self.assertEqual(pipe._batch_size, 2)
|
self.assertEqual(pipe._batch_size, 2)
|
||||||
self.assertEqual(pipe._num_workers, 1)
|
self.assertEqual(pipe._num_workers, 1)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_pathlike(self):
|
||||||
|
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
pipe.save_pretrained(d)
|
||||||
|
path = Path(d)
|
||||||
|
newpipe = pipeline(task="text-classification", model=path)
|
||||||
|
self.assertIsInstance(newpipe, TextClassificationPipeline)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pipeline_override(self):
|
def test_pipeline_override(self):
|
||||||
class MyPipeline(TextClassificationPipeline):
|
class MyPipeline(TextClassificationPipeline):
|
||||||
|
|||||||
Reference in New Issue
Block a user