Improve a add-new-pipeline docs a bit (#14485)
This commit is contained in:
@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
|
||||
from transformers import Pipeline
|
||||
|
||||
class MyPipeline(Pipeline):
|
||||
def _sanitize_parameters(self, **kwargs)
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "maybe_arg" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, inputs, maybe_arg=2)
|
||||
def preprocess(self, inputs, maybe_arg=2):
|
||||
model_input = Tensor(....)
|
||||
return {"model_input": model_input}
|
||||
|
||||
def _forward(self, model_inputs)
|
||||
def _forward(self, model_inputs):
|
||||
# model_inputs == {"model_input": model_input}
|
||||
oututs = self.model(**model_inputs)
|
||||
outputs = self.model(**model_inputs)
|
||||
# Maybe {"logits": Tensor(...)}
|
||||
return outputs
|
||||
|
||||
def postprocess(self, model_outputs)
|
||||
def postprocess(self, model_outputs):
|
||||
best_class = model_outputs["logits"].softmax(-1)
|
||||
return best_class
|
||||
|
||||
@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
|
||||
.. code-block::
|
||||
|
||||
|
||||
def postprocess(self, model_outputs, top_k=5)
|
||||
def postprocess(self, model_outputs, top_k=5):
|
||||
best_class = model_outputs["logits"].softmax(-1)
|
||||
# Add logic to handle top_k
|
||||
return best_class
|
||||
|
||||
def _sanitize_parameters(self, **kwargs)
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "maybe_arg" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
|
||||
Reference in New Issue
Block a user