Populate torch_dtype from model to pipeline (#28940)
* Populate torch_dtype from model to pipeline Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * use property Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * lint Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> * Remove default handling Signed-off-by: B-Step62 <yuki.watanabe@databricks.com> --------- Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
This commit is contained in:
@@ -861,7 +861,7 @@ class Pipeline(_ScikitCompat):
|
|||||||
raise ValueError(f"{device} unrecognized or not available.")
|
raise ValueError(f"{device} unrecognized or not available.")
|
||||||
else:
|
else:
|
||||||
self.device = device if device is not None else -1
|
self.device = device if device is not None else -1
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
self.binary_output = binary_output
|
self.binary_output = binary_output
|
||||||
|
|
||||||
# We shouldn't call `model.to()` for models loaded with accelerate
|
# We shouldn't call `model.to()` for models loaded with accelerate
|
||||||
@@ -964,6 +964,13 @@ class Pipeline(_ScikitCompat):
|
|||||||
"""
|
"""
|
||||||
return self(X)
|
return self(X)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def torch_dtype(self) -> Optional["torch.dtype"]:
|
||||||
|
"""
|
||||||
|
Torch dtype of the model (if it's Pytorch model), `None` otherwise.
|
||||||
|
"""
|
||||||
|
return getattr(self.model, "dtype", None)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def device_placement(self):
|
def device_placement(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -199,6 +199,29 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
|
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
|
||||||
self.assertEqual(len(outputs), 20)
|
self.assertEqual(len(outputs), 20)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_dtype_property(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model_id = "hf-internal-testing/tiny-random-distilbert"
|
||||||
|
|
||||||
|
# If dtype is specified in the pipeline constructor, the property should return that type
|
||||||
|
pipe = pipeline(model=model_id, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(pipe.torch_dtype, torch.float16)
|
||||||
|
|
||||||
|
# If the underlying model changes dtype, the property should return the new type
|
||||||
|
pipe.model.to(torch.bfloat16)
|
||||||
|
self.assertEqual(pipe.torch_dtype, torch.bfloat16)
|
||||||
|
|
||||||
|
# If dtype is NOT specified in the pipeline constructor, the property should just return
|
||||||
|
# the dtype of the underlying model (default)
|
||||||
|
pipe = pipeline(model=model_id)
|
||||||
|
self.assertEqual(pipe.torch_dtype, torch.float32)
|
||||||
|
|
||||||
|
# If underlying model doesn't have dtype property, simply return None
|
||||||
|
pipe.model = None
|
||||||
|
self.assertIsNone(pipe.torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class PipelineScikitCompatTest(unittest.TestCase):
|
class PipelineScikitCompatTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user