[CLAP] Support batched inputs for CLAP. Fixes pipeline issues (#21931)
* fix pipeline
* fix feature_extraction clap
* you can now batch the `is_longer` attribute
* add tests
* fixup
* add expected scores
* comment on `is_longer`
This commit is contained in:
@@ -44,6 +44,7 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
>>> audio = next(iter(dataset["train"]["audio"]))["array"]
|
||||
>>> classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
|
||||
>>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
|
||||
[{'score': 0.9995999932289124, 'label': 'Sound of a dog'}, {'score': 0.00040007088682614267, 'label': 'Sound of vaccum cleaner'}]
|
||||
```
|
||||
|
||||
|
||||
@@ -118,6 +119,7 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
|
||||
inputs["text_inputs"] = [text_inputs]
|
||||
return inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
candidate_labels = model_inputs.pop("candidate_labels")
|
||||
@@ -131,8 +133,8 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
outputs = self.model(**text_inputs, **model_inputs)
|
||||
|
||||
model_outputs = {
|
||||
"candidate_label": candidate_labels,
|
||||
"logits_per_audio": outputs.logits_per_audio,
|
||||
"candidate_labels": candidate_labels,
|
||||
"logits": outputs.logits_per_audio,
|
||||
}
|
||||
return model_outputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user