[CLAP] Fix few broken things (#21670)
* add `is_longer` * fix docstring * fix config class * fix loss * fix all doctests * fix order * fix last failing tests --------- Co-authored-by: arthur.zucker@gmail.com <arthur.zucker@gmail.com>
This commit is contained in:
@@ -898,8 +898,8 @@ class ClapAudioEncoder(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_features,
|
input_features,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
is_longer: Optional[torch.FloatTensor] = None,
|
is_longer: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states: Optional[bool] = False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
output_hidden_states_before_downsampling: Optional[bool] = False,
|
output_hidden_states_before_downsampling: Optional[bool] = False,
|
||||||
@@ -1673,7 +1673,7 @@ class ClapPreTrainedModel(PreTrainedModel):
|
|||||||
models.
|
models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = ClapTextConfig
|
config_class = ClapConfig
|
||||||
base_model_prefix = "clap"
|
base_model_prefix = "clap"
|
||||||
supports_gradient_checkpointing = False
|
supports_gradient_checkpointing = False
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]
|
||||||
@@ -1746,7 +1746,7 @@ class ClapAudioModel(ClapPreTrainedModel):
|
|||||||
>>> inputs = processor(audios=audio_sample, return_tensors="pt")
|
>>> inputs = processor(audios=audio_sample, return_tensors="pt")
|
||||||
|
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> last_hidden_state = outputs.audio_emmbeds
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
@@ -2069,6 +2069,7 @@ class ClapModel(ClapPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
input_features: Optional[torch.FloatTensor] = None,
|
input_features: Optional[torch.FloatTensor] = None,
|
||||||
|
is_longer: Optional[torch.BoolTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
return_loss: Optional[bool] = None,
|
return_loss: Optional[bool] = None,
|
||||||
@@ -2108,6 +2109,7 @@ class ClapModel(ClapPreTrainedModel):
|
|||||||
|
|
||||||
audio_outputs = self.audio_model(
|
audio_outputs = self.audio_model(
|
||||||
input_features=input_features,
|
input_features=input_features,
|
||||||
|
is_longer=is_longer,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -2141,7 +2143,7 @@ class ClapModel(ClapPreTrainedModel):
|
|||||||
loss = None
|
loss = None
|
||||||
if return_loss:
|
if return_loss:
|
||||||
caption_loss = contrastive_loss(logits_per_text)
|
caption_loss = contrastive_loss(logits_per_text)
|
||||||
audio_loss = contrastive_loss(logits_per_text.t())
|
audio_loss = contrastive_loss(logits_per_audio.t())
|
||||||
loss = (caption_loss + audio_loss) / 2.0
|
loss = (caption_loss + audio_loss) / 2.0
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -2203,7 +2205,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
|||||||
>>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
|
>>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
|
>>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
|
||||||
|
|
||||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
>>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")
|
||||||
|
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> text_embeds = outputs.text_embeds
|
>>> text_embeds = outputs.text_embeds
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ class ClapAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in CLAP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in CLAP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
model = ClapAudioModelWithProjection.from_pretrained(model_name)
|
model = ClapAudioModelWithProjection.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertTrue(hasattr(model, "visual_projection"))
|
self.assertTrue(hasattr(model, "audio_projection"))
|
||||||
|
|
||||||
|
|
||||||
class ClapTextModelTester:
|
class ClapTextModelTester:
|
||||||
|
|||||||
Reference in New Issue
Block a user