Improve a add-new-pipeline docs a bit (#14485)
This commit is contained in:
@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
|
|||||||
from transformers import Pipeline
|
from transformers import Pipeline
|
||||||
|
|
||||||
class MyPipeline(Pipeline):
|
class MyPipeline(Pipeline):
|
||||||
def _sanitize_parameters(self, **kwargs)
|
def _sanitize_parameters(self, **kwargs):
|
||||||
preprocess_kwargs = {}
|
preprocess_kwargs = {}
|
||||||
if "maybe_arg" in kwargs:
|
if "maybe_arg" in kwargs:
|
||||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||||
return preprocess_kwargs, {}, {}
|
return preprocess_kwargs, {}, {}
|
||||||
|
|
||||||
def preprocess(self, inputs, maybe_arg=2)
|
def preprocess(self, inputs, maybe_arg=2):
|
||||||
model_input = Tensor(....)
|
model_input = Tensor(....)
|
||||||
return {"model_input": model_input}
|
return {"model_input": model_input}
|
||||||
|
|
||||||
def _forward(self, model_inputs)
|
def _forward(self, model_inputs):
|
||||||
# model_inputs == {"model_input": model_input}
|
# model_inputs == {"model_input": model_input}
|
||||||
oututs = self.model(**model_inputs)
|
outputs = self.model(**model_inputs)
|
||||||
# Maybe {"logits": Tensor(...)}
|
# Maybe {"logits": Tensor(...)}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def postprocess(self, model_outputs)
|
def postprocess(self, model_outputs):
|
||||||
best_class = model_outputs["logits"].softmax(-1)
|
best_class = model_outputs["logits"].softmax(-1)
|
||||||
return best_class
|
return best_class
|
||||||
|
|
||||||
@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
|
|||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
|
|
||||||
def postprocess(self, model_outputs, top_k=5)
|
def postprocess(self, model_outputs, top_k=5):
|
||||||
best_class = model_outputs["logits"].softmax(-1)
|
best_class = model_outputs["logits"].softmax(-1)
|
||||||
# Add logic to handle top_k
|
# Add logic to handle top_k
|
||||||
return best_class
|
return best_class
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs)
|
def _sanitize_parameters(self, **kwargs):
|
||||||
preprocess_kwargs = {}
|
preprocess_kwargs = {}
|
||||||
if "maybe_arg" in kwargs:
|
if "maybe_arg" in kwargs:
|
||||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||||
|
|||||||
Reference in New Issue
Block a user