fix the issue that the output dict of jit model could not get [0] (#21354)

This commit is contained in:
Wang, Yi
2023-01-30 22:23:55 +08:00
committed by GitHub
parent c749bd405e
commit f3a7befffa

View File

@@ -239,7 +239,8 @@ class TokenClassificationPipeline(Pipeline):
if self.framework == "tf":
logits = self.model(model_inputs.data)[0]
else:
logits = self.model(**model_inputs)[0]
output = self.model(**model_inputs)
logits = output["logits"] if isinstance(output, dict) else output[0]
return {
"logits": logits,