[CLAP] Fix few broken things (#21670)
* add `is_longer` * fix docstring * fix config class * fix loss * fix all doctests * fix order * fix last failing tests --------- Co-authored-by: arthur.zucker@gmail.com <arthur.zucker@gmail.com>
This commit is contained in:
@@ -898,8 +898,8 @@ class ClapAudioEncoder(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_features,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
is_longer: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
output_hidden_states_before_downsampling: Optional[bool] = False,
|
||||
@@ -1673,7 +1673,7 @@ class ClapPreTrainedModel(PreTrainedModel):
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = ClapTextConfig
|
||||
config_class = ClapConfig
|
||||
base_model_prefix = "clap"
|
||||
supports_gradient_checkpointing = False
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]
|
||||
@@ -1746,7 +1746,7 @@ class ClapAudioModel(ClapPreTrainedModel):
|
||||
>>> inputs = processor(audios=audio_sample, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.audio_emmbeds
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -2069,6 +2069,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
input_features: Optional[torch.FloatTensor] = None,
|
||||
is_longer: Optional[torch.BoolTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
return_loss: Optional[bool] = None,
|
||||
@@ -2108,6 +2109,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
|
||||
audio_outputs = self.audio_model(
|
||||
input_features=input_features,
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -2141,7 +2143,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
loss = None
|
||||
if return_loss:
|
||||
caption_loss = contrastive_loss(logits_per_text)
|
||||
audio_loss = contrastive_loss(logits_per_text.t())
|
||||
audio_loss = contrastive_loss(logits_per_audio.t())
|
||||
loss = (caption_loss + audio_loss) / 2.0
|
||||
|
||||
if not return_dict:
|
||||
@@ -2203,7 +2205,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
>>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> text_embeds = outputs.text_embeds
|
||||
|
||||
@@ -268,7 +268,7 @@ class ClapAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
for model_name in CLAP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = ClapAudioModelWithProjection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "visual_projection"))
|
||||
self.assertTrue(hasattr(model, "audio_projection"))
|
||||
|
||||
|
||||
class ClapTextModelTester:
|
||||
|
||||
Reference in New Issue
Block a user