fixing name position_embeddings to object_queries (#24652)
* fixing name position_embeddings to object_queries * [fix] renaming variable and docstring to object queries * [fix] comment position_embedding to object queries * [feat] changes from make-fix-copies to keep consistency * Revert "[feat] changes from make-fix-copies to keep consistency" This reverts commit 56e3e9ede1d32f7aeefba707ddfaf12c9b4b9e7e. * [tests] fix wrong expected score * [fix] wrong assignment causing wrong tensor shapes * [fix] fixing position_embeddings to object queries to keep consistency (make fix copies) * [fix] make fix copies, renaming position_embeddings to object_queries * [fix] positional_embeddings to object queries, fixes from make fix copies * [fix] comments from make fix copies * [fix] adding args validation to keep version support * [fix] adding args validation to keep version support - conditional detr * [fix] adding args validation to keep version support - maskformer * [style] make fixup style fixes * [feat] adding args checking * [feat] fixcopies and args checking * make fixup * make fixup --------- Co-authored-by: Lorenzobattistela <lorenzobattistela@gmail.com>
This commit is contained in:
committed by
GitHub
parent
39c37fe45c
commit
99c3d44906
@@ -564,34 +564,79 @@ class DetrAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
||||||
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
|
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
|
||||||
return tensor if position_embeddings is None else tensor + position_embeddings
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
return tensor if object_queries is None else tensor + object_queries
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
key_value_position_embeddings: Optional[torch.Tensor] = None,
|
spatial_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
|
||||||
|
)
|
||||||
|
spatial_position_embeddings = key_value_position_embeddings
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
batch_size, target_len, embed_dim = hidden_states.size()
|
batch_size, target_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
# add position embeddings to the hidden states before projecting to queries and keys
|
# add position embeddings to the hidden states before projecting to queries and keys
|
||||||
if position_embeddings is not None:
|
if object_queries is not None:
|
||||||
hidden_states_original = hidden_states
|
hidden_states_original = hidden_states
|
||||||
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
|
hidden_states = self.with_pos_embed(hidden_states, object_queries)
|
||||||
|
|
||||||
# add key-value position embeddings to the key value states
|
# add key-value position embeddings to the key value states
|
||||||
if key_value_position_embeddings is not None:
|
if spatial_position_embeddings is not None:
|
||||||
key_value_states_original = key_value_states
|
key_value_states_original = key_value_states
|
||||||
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
|
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states) * self.scaling
|
query_states = self.q_proj(hidden_states) * self.scaling
|
||||||
@@ -799,8 +844,9 @@ class ConditionalDetrEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor = None,
|
object_queries: torch.Tensor = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -808,16 +854,33 @@ class ConditionalDetrEncoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
|
Object queries (also called content embeddings), to be added to the hidden states.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights = self.self_attn(
|
hidden_states, attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -894,13 +957,14 @@ class ConditionalDetrDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
query_position_embeddings: Optional[torch.Tensor] = None,
|
query_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
query_sine_embed: Optional[torch.Tensor] = None,
|
query_sine_embed: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
is_first: Optional[bool] = False,
|
is_first: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -908,11 +972,11 @@ class ConditionalDetrDecoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*):
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
object_queries that are added to the queries and keys
|
||||||
in the cross-attention layer.
|
in the cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
position embeddings that are added to the queries and keys
|
||||||
in the self-attention layer.
|
in the self-attention layer.
|
||||||
encoder_hidden_states (`torch.FloatTensor`):
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -923,6 +987,22 @@ class ConditionalDetrDecoderLayer(nn.Module):
|
|||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# ========== Begin of Self-Attention =============
|
# ========== Begin of Self-Attention =============
|
||||||
@@ -963,7 +1043,7 @@ class ConditionalDetrDecoderLayer(nn.Module):
|
|||||||
batch_size, num_queries, n_model = q_content.shape
|
batch_size, num_queries, n_model = q_content.shape
|
||||||
_, source_len, _ = k_content.shape
|
_, source_len, _ = k_content.shape
|
||||||
|
|
||||||
k_pos = self.ca_kpos_proj(position_embeddings)
|
k_pos = self.ca_kpos_proj(object_queries)
|
||||||
|
|
||||||
# For the first decoder layer, we concatenate the positional embedding predicted from
|
# For the first decoder layer, we concatenate the positional embedding predicted from
|
||||||
# the object query (the positional embedding) into the original query (key) in DETR.
|
# the object query (the positional embedding) into the original query (key) in DETR.
|
||||||
@@ -1159,7 +1239,7 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|||||||
|
|
||||||
Small tweak for ConditionalDETR:
|
Small tweak for ConditionalDETR:
|
||||||
|
|
||||||
- position_embeddings are added to the forward pass.
|
- object_queries are added to the forward pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: ConditionalDetrConfig
|
config: ConditionalDetrConfig
|
||||||
@@ -1182,10 +1262,11 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1200,8 +1281,8 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
Position embeddings that are added to the queries and keys in each self-attention layer.
|
Object queries that are added to the queries in each self-attention layer.
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
@@ -1212,6 +1293,22 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1241,11 +1338,11 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
|
|||||||
if to_drop:
|
if to_drop:
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
# we add position_embeddings as extra input to the encoder_layer
|
# we add object_queries as extra input to the encoder_layer
|
||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1272,7 +1369,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
|
|
||||||
Some small tweaks for Conditional DETR:
|
Some small tweaks for Conditional DETR:
|
||||||
|
|
||||||
- position_embeddings and query_position_embeddings are added to the forward pass.
|
- object_queries and query_position_embeddings are added to the forward pass.
|
||||||
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1305,11 +1402,12 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
query_position_embeddings=None,
|
query_position_embeddings=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1333,7 +1431,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
- 1 for pixels that are real (i.e. **not masked**),
|
- 1 for pixels that are real (i.e. **not masked**),
|
||||||
- 0 for pixels that are padding (i.e. **masked**).
|
- 0 for pixels that are padding (i.e. **masked**).
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
||||||
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
@@ -1346,6 +1444,22 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1413,7 +1527,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
combined_attention_mask,
|
combined_attention_mask,
|
||||||
position_embeddings,
|
object_queries,
|
||||||
query_position_embeddings,
|
query_position_embeddings,
|
||||||
query_sine_embed,
|
query_sine_embed,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
@@ -1425,7 +1539,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
query_sine_embed=query_sine_embed,
|
query_sine_embed=query_sine_embed,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
@@ -1493,8 +1607,8 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|||||||
|
|
||||||
# Create backbone + positional encoding
|
# Create backbone + positional encoding
|
||||||
backbone = ConditionalDetrConvEncoder(config)
|
backbone = ConditionalDetrConvEncoder(config)
|
||||||
position_embeddings = build_position_encoding(config)
|
object_queries = build_position_encoding(config)
|
||||||
self.backbone = ConditionalDetrConvModel(backbone, position_embeddings)
|
self.backbone = ConditionalDetrConvModel(backbone, object_queries)
|
||||||
|
|
||||||
# Create projection layer
|
# Create projection layer
|
||||||
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
||||||
@@ -1578,7 +1692,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|||||||
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
|
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
|
||||||
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
||||||
# pixel_mask should be of shape (batch_size, height, width)
|
# pixel_mask should be of shape (batch_size, height, width)
|
||||||
features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
|
features, object_queries_list = self.backbone(pixel_values, pixel_mask)
|
||||||
|
|
||||||
# get final feature map and downsampled mask
|
# get final feature map and downsampled mask
|
||||||
feature_map, mask = features[-1]
|
feature_map, mask = features[-1]
|
||||||
@@ -1589,21 +1703,21 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|||||||
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
||||||
projected_feature_map = self.input_projection(feature_map)
|
projected_feature_map = self.input_projection(feature_map)
|
||||||
|
|
||||||
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
# Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
||||||
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
||||||
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
||||||
|
|
||||||
flattened_mask = mask.flatten(1)
|
flattened_mask = mask.flatten(1)
|
||||||
|
|
||||||
# Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
|
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
|
||||||
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
||||||
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=flattened_features,
|
inputs_embeds=flattened_features,
|
||||||
attention_mask=flattened_mask,
|
attention_mask=flattened_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1616,7 +1730,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
|
# Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
|
||||||
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||||
queries = torch.zeros_like(query_position_embeddings)
|
queries = torch.zeros_like(query_position_embeddings)
|
||||||
|
|
||||||
@@ -1624,7 +1738,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
|
|||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
inputs_embeds=queries,
|
inputs_embeds=queries,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=flattened_mask,
|
encoder_attention_mask=flattened_mask,
|
||||||
@@ -1940,29 +2054,29 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|||||||
if pixel_mask is None:
|
if pixel_mask is None:
|
||||||
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
||||||
|
|
||||||
# First, get list of feature maps and position embeddings
|
# First, get list of feature maps and object_queries
|
||||||
features, position_embeddings_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
|
features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
|
||||||
|
|
||||||
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
||||||
feature_map, mask = features[-1]
|
feature_map, mask = features[-1]
|
||||||
batch_size, num_channels, height, width = feature_map.shape
|
batch_size, num_channels, height, width = feature_map.shape
|
||||||
projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
|
projected_feature_map = self.conditional_detr.model.input_projection(feature_map)
|
||||||
|
|
||||||
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
# Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
||||||
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
||||||
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
||||||
|
|
||||||
flattened_mask = mask.flatten(1)
|
flattened_mask = mask.flatten(1)
|
||||||
|
|
||||||
# Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
|
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
|
||||||
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
||||||
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
encoder_outputs = self.conditional_detr.model.encoder(
|
encoder_outputs = self.conditional_detr.model.encoder(
|
||||||
inputs_embeds=flattened_features,
|
inputs_embeds=flattened_features,
|
||||||
attention_mask=flattened_mask,
|
attention_mask=flattened_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1975,7 +2089,7 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
|
# Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
|
||||||
query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
|
query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
|
||||||
batch_size, 1, 1
|
batch_size, 1, 1
|
||||||
)
|
)
|
||||||
@@ -1985,7 +2099,7 @@ class ConditionalDetrForSegmentation(ConditionalDetrPreTrainedModel):
|
|||||||
decoder_outputs = self.conditional_detr.model.decoder(
|
decoder_outputs = self.conditional_detr.model.decoder(
|
||||||
inputs_embeds=queries,
|
inputs_embeds=queries,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=flattened_mask,
|
encoder_attention_mask=flattened_mask,
|
||||||
|
|||||||
@@ -528,34 +528,79 @@ class DetrAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
||||||
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
|
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
|
||||||
return tensor if position_embeddings is None else tensor + position_embeddings
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
return tensor if object_queries is None else tensor + object_queries
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
key_value_position_embeddings: Optional[torch.Tensor] = None,
|
spatial_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
|
||||||
|
)
|
||||||
|
spatial_position_embeddings = key_value_position_embeddings
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
batch_size, target_len, embed_dim = hidden_states.size()
|
batch_size, target_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
# add position embeddings to the hidden states before projecting to queries and keys
|
# add position embeddings to the hidden states before projecting to queries and keys
|
||||||
if position_embeddings is not None:
|
if object_queries is not None:
|
||||||
hidden_states_original = hidden_states
|
hidden_states_original = hidden_states
|
||||||
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
|
hidden_states = self.with_pos_embed(hidden_states, object_queries)
|
||||||
|
|
||||||
# add key-value position embeddings to the key value states
|
# add key-value position embeddings to the key value states
|
||||||
if key_value_position_embeddings is not None:
|
if spatial_position_embeddings is not None:
|
||||||
key_value_states_original = key_value_states
|
key_value_states_original = key_value_states
|
||||||
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
|
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states) * self.scaling
|
query_states = self.q_proj(hidden_states) * self.scaling
|
||||||
@@ -645,8 +690,9 @@ class DetrEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor = None,
|
object_queries: torch.Tensor = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -654,16 +700,33 @@ class DetrEncoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
|
Object queries (also called content embeddings), to be added to the hidden states.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states, attn_weights = self.self_attn(
|
hidden_states, attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -723,11 +786,12 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
query_position_embeddings: Optional[torch.Tensor] = None,
|
query_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -735,8 +799,8 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*):
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
object_queries that are added to the hidden states
|
||||||
in the cross-attention layer.
|
in the cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
position embeddings that are added to the queries and keys
|
||||||
@@ -750,12 +814,28 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights = self.self_attn(
|
hidden_states, self_attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -771,10 +851,10 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
hidden_states, cross_attn_weights = self.encoder_attn(
|
hidden_states, cross_attn_weights = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
key_value_position_embeddings=position_embeddings,
|
spatial_position_embeddings=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -913,7 +993,7 @@ class DetrEncoder(DetrPreTrainedModel):
|
|||||||
|
|
||||||
Small tweak for DETR:
|
Small tweak for DETR:
|
||||||
|
|
||||||
- position_embeddings are added to the forward pass.
|
- object_queries are added to the forward pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: DetrConfig
|
config: DetrConfig
|
||||||
@@ -936,10 +1016,11 @@ class DetrEncoder(DetrPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -954,8 +1035,8 @@ class DetrEncoder(DetrPreTrainedModel):
|
|||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
Position embeddings that are added to the queries and keys in each self-attention layer.
|
Object queries that are added to the queries in each self-attention layer.
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
@@ -966,6 +1047,22 @@ class DetrEncoder(DetrPreTrainedModel):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -995,11 +1092,11 @@ class DetrEncoder(DetrPreTrainedModel):
|
|||||||
if to_drop:
|
if to_drop:
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
# we add position_embeddings as extra input to the encoder_layer
|
# we add object_queries as extra input to the encoder_layer
|
||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1026,7 +1123,7 @@ class DetrDecoder(DetrPreTrainedModel):
|
|||||||
|
|
||||||
Some small tweaks for DETR:
|
Some small tweaks for DETR:
|
||||||
|
|
||||||
- position_embeddings and query_position_embeddings are added to the forward pass.
|
- object_queries and query_position_embeddings are added to the forward pass.
|
||||||
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1052,11 +1149,12 @@ class DetrDecoder(DetrPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
query_position_embeddings=None,
|
query_position_embeddings=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1080,10 +1178,11 @@ class DetrDecoder(DetrPreTrainedModel):
|
|||||||
- 1 for pixels that are real (i.e. **not masked**),
|
- 1 for pixels that are real (i.e. **not masked**),
|
||||||
- 0 for pixels that are padding (i.e. **masked**).
|
- 0 for pixels that are padding (i.e. **masked**).
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
Object queries that are added to the queries and keys in each cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
||||||
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -1093,6 +1192,22 @@ class DetrDecoder(DetrPreTrainedModel):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1155,7 +1270,7 @@ class DetrDecoder(DetrPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
@@ -1213,8 +1328,8 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
|
|
||||||
# Create backbone + positional encoding
|
# Create backbone + positional encoding
|
||||||
backbone = DetrConvEncoder(config)
|
backbone = DetrConvEncoder(config)
|
||||||
position_embeddings = build_position_encoding(config)
|
object_queries = build_position_encoding(config)
|
||||||
self.backbone = DetrConvModel(backbone, position_embeddings)
|
self.backbone = DetrConvModel(backbone, object_queries)
|
||||||
|
|
||||||
# Create projection layer
|
# Create projection layer
|
||||||
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
||||||
@@ -1298,7 +1413,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
# First, send pixel_values + pixel_mask through Backbone to obtain the features
|
# First, send pixel_values + pixel_mask through Backbone to obtain the features
|
||||||
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
# pixel_values should be of shape (batch_size, num_channels, height, width)
|
||||||
# pixel_mask should be of shape (batch_size, height, width)
|
# pixel_mask should be of shape (batch_size, height, width)
|
||||||
features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
|
features, object_queries_list = self.backbone(pixel_values, pixel_mask)
|
||||||
|
|
||||||
# get final feature map and downsampled mask
|
# get final feature map and downsampled mask
|
||||||
feature_map, mask = features[-1]
|
feature_map, mask = features[-1]
|
||||||
@@ -1312,7 +1427,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
||||||
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
||||||
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
||||||
|
|
||||||
flattened_mask = mask.flatten(1)
|
flattened_mask = mask.flatten(1)
|
||||||
|
|
||||||
@@ -1323,7 +1438,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=flattened_features,
|
inputs_embeds=flattened_features,
|
||||||
attention_mask=flattened_mask,
|
attention_mask=flattened_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1336,7 +1451,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fifth, send query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
|
# Fifth, send query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
|
||||||
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||||
queries = torch.zeros_like(query_position_embeddings)
|
queries = torch.zeros_like(query_position_embeddings)
|
||||||
|
|
||||||
@@ -1344,7 +1459,7 @@ class DetrModel(DetrPreTrainedModel):
|
|||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
inputs_embeds=queries,
|
inputs_embeds=queries,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=flattened_mask,
|
encoder_attention_mask=flattened_mask,
|
||||||
@@ -1640,7 +1755,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
pixel_mask = torch.ones((batch_size, height, width), device=device)
|
||||||
|
|
||||||
# First, get list of feature maps and position embeddings
|
# First, get list of feature maps and position embeddings
|
||||||
features, position_embeddings_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
|
features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
|
||||||
|
|
||||||
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
||||||
feature_map, mask = features[-1]
|
feature_map, mask = features[-1]
|
||||||
@@ -1650,7 +1765,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
||||||
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
||||||
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)
|
||||||
|
|
||||||
flattened_mask = mask.flatten(1)
|
flattened_mask = mask.flatten(1)
|
||||||
|
|
||||||
@@ -1661,7 +1776,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
encoder_outputs = self.detr.model.encoder(
|
encoder_outputs = self.detr.model.encoder(
|
||||||
inputs_embeds=flattened_features,
|
inputs_embeds=flattened_features,
|
||||||
attention_mask=flattened_mask,
|
attention_mask=flattened_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1684,7 +1799,7 @@ class DetrForSegmentation(DetrPreTrainedModel):
|
|||||||
decoder_outputs = self.detr.model.decoder(
|
decoder_outputs = self.detr.model.decoder(
|
||||||
inputs_embeds=queries,
|
inputs_embeds=queries,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=flattened_mask,
|
encoder_attention_mask=flattened_mask,
|
||||||
|
|||||||
@@ -437,34 +437,79 @@ class DetrAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
||||||
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
|
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
|
||||||
return tensor if position_embeddings is None else tensor + position_embeddings
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
return tensor if object_queries is None else tensor + object_queries
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
key_value_position_embeddings: Optional[torch.Tensor] = None,
|
spatial_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
|
||||||
|
)
|
||||||
|
spatial_position_embeddings = key_value_position_embeddings
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
batch_size, target_len, embed_dim = hidden_states.size()
|
batch_size, target_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
# add position embeddings to the hidden states before projecting to queries and keys
|
# add position embeddings to the hidden states before projecting to queries and keys
|
||||||
if position_embeddings is not None:
|
if object_queries is not None:
|
||||||
hidden_states_original = hidden_states
|
hidden_states_original = hidden_states
|
||||||
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
|
hidden_states = self.with_pos_embed(hidden_states, object_queries)
|
||||||
|
|
||||||
# add key-value position embeddings to the key value states
|
# add key-value position embeddings to the key value states
|
||||||
if key_value_position_embeddings is not None:
|
if spatial_position_embeddings is not None:
|
||||||
key_value_states_original = key_value_states
|
key_value_states_original = key_value_states
|
||||||
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
|
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states) * self.scaling
|
query_states = self.q_proj(hidden_states) * self.scaling
|
||||||
@@ -563,11 +608,12 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
query_position_embeddings: Optional[torch.Tensor] = None,
|
query_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -575,8 +621,8 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*):
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
object_queries that are added to the hidden states
|
||||||
in the cross-attention layer.
|
in the cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
position embeddings that are added to the queries and keys
|
||||||
@@ -590,12 +636,28 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights = self.self_attn(
|
hidden_states, self_attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -611,10 +673,10 @@ class DetrDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
hidden_states, cross_attn_weights = self.encoder_attn(
|
hidden_states, cross_attn_weights = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
key_value_position_embeddings=position_embeddings,
|
spatial_position_embeddings=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -662,7 +724,7 @@ class DetrDecoder(nn.Module):
|
|||||||
|
|
||||||
Some small tweaks for DETR:
|
Some small tweaks for DETR:
|
||||||
|
|
||||||
- position_embeddings and query_position_embeddings are added to the forward pass.
|
- object_queries and query_position_embeddings are added to the forward pass.
|
||||||
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -687,11 +749,12 @@ class DetrDecoder(nn.Module):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
query_position_embeddings=None,
|
query_position_embeddings=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -715,7 +778,7 @@ class DetrDecoder(nn.Module):
|
|||||||
- 1 for pixels that are real (i.e. **not masked**),
|
- 1 for pixels that are real (i.e. **not masked**),
|
||||||
- 0 for pixels that are padding (i.e. **masked**).
|
- 0 for pixels that are padding (i.e. **masked**).
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
||||||
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
@@ -728,6 +791,21 @@ class DetrDecoder(nn.Module):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -788,7 +866,7 @@ class DetrDecoder(nn.Module):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
@@ -1438,23 +1516,23 @@ class MaskFormerTransformerModule(nn.Module):
|
|||||||
) -> DetrDecoderOutput:
|
) -> DetrDecoderOutput:
|
||||||
if self.input_projection is not None:
|
if self.input_projection is not None:
|
||||||
image_features = self.input_projection(image_features)
|
image_features = self.input_projection(image_features)
|
||||||
position_embeddings = self.position_embedder(image_features)
|
object_queries = self.position_embedder(image_features)
|
||||||
# repeat the queries "q c -> b q c"
|
# repeat the queries "q c -> b q c"
|
||||||
batch_size = image_features.shape[0]
|
batch_size = image_features.shape[0]
|
||||||
queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||||
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
|
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
|
||||||
|
|
||||||
batch_size, num_channels, height, width = image_features.shape
|
batch_size, num_channels, height, width = image_features.shape
|
||||||
# rearrange both image_features and position_embeddings "b c h w -> b (h w) c"
|
# rearrange both image_features and object_queries "b c h w -> b (h w) c"
|
||||||
image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||||
|
|
||||||
decoder_output: DetrDecoderOutput = self.decoder(
|
decoder_output: DetrDecoderOutput = self.decoder(
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=image_features,
|
encoder_hidden_states=image_features,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=queries_embeddings,
|
query_position_embeddings=queries_embeddings,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -469,34 +469,79 @@ class TableTransformerAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
|
||||||
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
    """Return `tensor + object_queries`, or `tensor` itself when no object queries are given.

    For backward compatibility the deprecated keyword `position_embeddings` is accepted in
    place of `object_queries`; a deprecation warning is emitted through `logger.warning_once`
    when it is used.

    Raises:
        ValueError: if an unknown keyword argument is passed, or if both
            `position_embeddings` and `object_queries` are provided.
    """
    deprecated_value = kwargs.pop("position_embeddings", None)

    # Only the deprecated alias is tolerated as an extra keyword.
    if len(kwargs) > 0:
        raise ValueError(f"Unexpected arguments {kwargs.keys()}")

    if deprecated_value is not None:
        if object_queries is not None:
            raise ValueError(
                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
            )
        logger.warning_once(
            "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
        )
        # Use the value passed under the old name.
        object_queries = deprecated_value

    if object_queries is None:
        return tensor
    return tensor + object_queries
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
key_value_position_embeddings: Optional[torch.Tensor] = None,
|
spatial_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
key_value_position_embeddings = kwargs.pop("key_value_position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None and spatial_position_embeddings is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both key_value_position_embeddings and spatial_position_embeddings. Please use just spatial_position_embeddings"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
|
if key_value_position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"key_value_position_embeddings has been deprecated and will be removed in v4.34. Please use spatial_position_embeddings instead"
|
||||||
|
)
|
||||||
|
spatial_position_embeddings = key_value_position_embeddings
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
batch_size, target_len, embed_dim = hidden_states.size()
|
batch_size, target_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
# add position embeddings to the hidden states before projecting to queries and keys
|
# add position embeddings to the hidden states before projecting to queries and keys
|
||||||
if position_embeddings is not None:
|
if object_queries is not None:
|
||||||
hidden_states_original = hidden_states
|
hidden_states_original = hidden_states
|
||||||
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
|
hidden_states = self.with_pos_embed(hidden_states, object_queries)
|
||||||
|
|
||||||
# add key-value position embeddings to the key value states
|
# add key-value position embeddings to the key value states
|
||||||
if key_value_position_embeddings is not None:
|
if spatial_position_embeddings is not None:
|
||||||
key_value_states_original = key_value_states
|
key_value_states_original = key_value_states
|
||||||
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
|
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states) * self.scaling
|
query_states = self.q_proj(hidden_states) * self.scaling
|
||||||
@@ -587,7 +632,7 @@ class TableTransformerEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor = None,
|
object_queries: torch.Tensor = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -596,7 +641,7 @@ class TableTransformerEncoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
|
object_queries (`torch.FloatTensor`, *optional*): object queries, to be added to hidden_states.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -607,7 +652,7 @@ class TableTransformerEncoderLayer(nn.Module):
|
|||||||
hidden_states, attn_weights = self.self_attn(
|
hidden_states, attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -668,7 +713,7 @@ class TableTransformerDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
object_queries: Optional[torch.Tensor] = None,
|
||||||
query_position_embeddings: Optional[torch.Tensor] = None,
|
query_position_embeddings: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -680,11 +725,11 @@ class TableTransformerDecoderLayer(nn.Module):
|
|||||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||||
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
|
||||||
values.
|
values.
|
||||||
position_embeddings (`torch.FloatTensor`, *optional*):
|
object_queries (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
object queries that are added to the queries and keys
|
||||||
in the cross-attention layer.
|
in the cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
query_position_embeddings (`torch.FloatTensor`, *optional*):
|
||||||
position embeddings that are added to the queries and keys
|
object queries that are added to the queries and keys
|
||||||
in the self-attention layer.
|
in the self-attention layer.
|
||||||
encoder_hidden_states (`torch.FloatTensor`):
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
@@ -701,7 +746,7 @@ class TableTransformerDecoderLayer(nn.Module):
|
|||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights = self.self_attn(
|
hidden_states, self_attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -717,10 +762,10 @@ class TableTransformerDecoderLayer(nn.Module):
|
|||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
hidden_states, cross_attn_weights = self.encoder_attn(
|
hidden_states, cross_attn_weights = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
position_embeddings=query_position_embeddings,
|
object_queries=query_position_embeddings,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
key_value_position_embeddings=position_embeddings,
|
spatial_position_embeddings=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -854,7 +899,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
|
|||||||
|
|
||||||
Small tweak for Table Transformer:
|
Small tweak for Table Transformer:
|
||||||
|
|
||||||
- position_embeddings are added to the forward pass.
|
- object_queries are added to the forward pass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: TableTransformerConfig
|
config: TableTransformerConfig
|
||||||
@@ -877,7 +922,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -895,7 +940,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
|
|||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
Position embeddings that are added to the queries and keys in each self-attention layer.
|
Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
@@ -936,11 +981,11 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
|
|||||||
if to_drop:
|
if to_drop:
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
# we add position_embeddings as extra input to the encoder_layer
|
# we add object_queries as extra input to the encoder_layer
|
||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -970,7 +1015,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
|
|||||||
|
|
||||||
Some small tweaks for TABLE_TRANSFORMER:
|
Some small tweaks for TABLE_TRANSFORMER:
|
||||||
|
|
||||||
- position_embeddings and query_position_embeddings are added to the forward pass.
|
- object_queries and query_position_embeddings are added to the forward pass.
|
||||||
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -996,11 +1041,12 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings=None,
|
object_queries=None,
|
||||||
query_position_embeddings=None,
|
query_position_embeddings=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1024,10 +1070,11 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
|
|||||||
- 1 for pixels that are real (i.e. **not masked**),
|
- 1 for pixels that are real (i.e. **not masked**),
|
||||||
- 0 for pixels that are padding (i.e. **masked**).
|
- 0 for pixels that are padding (i.e. **masked**).
|
||||||
|
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Position embeddings that are added to the queries and keys in each cross-attention layer.
|
Object queries that are added to the queries and keys in each cross-attention layer.
|
||||||
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
||||||
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
|
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
returned tensors for more detail.
|
returned tensors for more detail.
|
||||||
@@ -1037,6 +1084,22 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
position_embeddings = kwargs.pop("position_embeddings", None)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError(f"Unexpected arguments {kwargs.keys()}")
|
||||||
|
|
||||||
|
if position_embeddings is not None and object_queries is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify both position_embeddings and object_queries. Please use just object_queries"
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
|
||||||
|
)
|
||||||
|
object_queries = position_embeddings
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -1099,7 +1162,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
@@ -1158,8 +1221,8 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
|
|||||||
|
|
||||||
# Create backbone + positional encoding
|
# Create backbone + positional encoding
|
||||||
backbone = TableTransformerConvEncoder(config)
|
backbone = TableTransformerConvEncoder(config)
|
||||||
position_embeddings = build_position_encoding(config)
|
object_queries = build_position_encoding(config)
|
||||||
self.backbone = TableTransformerConvModel(backbone, position_embeddings)
|
self.backbone = TableTransformerConvModel(backbone, object_queries)
|
||||||
|
|
||||||
# Create projection layer
|
# Create projection layer
|
||||||
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
|
||||||
@@ -1254,21 +1317,21 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
|
|||||||
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
|
||||||
projected_feature_map = self.input_projection(feature_map)
|
projected_feature_map = self.input_projection(feature_map)
|
||||||
|
|
||||||
# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
# Third, flatten the feature map + object queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
|
||||||
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
|
||||||
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
|
||||||
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
object_queries = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
|
||||||
|
|
||||||
flattened_mask = mask.flatten(1)
|
flattened_mask = mask.flatten(1)
|
||||||
|
|
||||||
# Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
|
# Fourth, sent flattened_features + flattened_mask + object queries through encoder
|
||||||
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
|
||||||
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=flattened_features,
|
inputs_embeds=flattened_features,
|
||||||
attention_mask=flattened_mask,
|
attention_mask=flattened_mask,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1281,7 +1344,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
|
|||||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
|
# Fifth, sent query embeddings + object queries through the decoder (which is conditioned on the encoder output)
|
||||||
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||||
queries = torch.zeros_like(query_position_embeddings)
|
queries = torch.zeros_like(query_position_embeddings)
|
||||||
|
|
||||||
@@ -1289,7 +1352,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
|
|||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
inputs_embeds=queries,
|
inputs_embeds=queries,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=position_embeddings,
|
object_queries=object_queries,
|
||||||
query_position_embeddings=query_position_embeddings,
|
query_position_embeddings=query_position_embeddings,
|
||||||
encoder_hidden_states=encoder_outputs[0],
|
encoder_hidden_states=encoder_outputs[0],
|
||||||
encoder_attention_mask=flattened_mask,
|
encoder_attention_mask=flattened_mask,
|
||||||
|
|||||||
@@ -606,7 +606,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
|||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
expected_number_of_segments = 5
|
expected_number_of_segments = 5
|
||||||
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
|
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}
|
||||||
|
|
||||||
number_of_unique_segments = len(torch.unique(results["segmentation"]))
|
number_of_unique_segments = len(torch.unique(results["segmentation"]))
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
Reference in New Issue
Block a user