interpolation added for TVP. (#30863)
* Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * docstring and argument fix * doc fixes and test case fix suggested in review. * varibale typo fix * styling and name fixes for padding interpolation flag.
This commit is contained in:
@@ -193,34 +193,81 @@ class TvpVisualInputEmbedding(nn.Module):
|
|||||||
self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
|
self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
|
||||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
|
||||||
|
self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings
|
||||||
|
|
||||||
def add_2d_positional_embeddings(self, grid):
|
def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high
|
||||||
|
resolution images (high resolution videos).
|
||||||
|
|
||||||
|
"""
|
||||||
|
h0 = w0 = 1
|
||||||
|
# if height dimension is to be interpolated
|
||||||
|
if height > self.max_grid_row_position_embeddings:
|
||||||
|
h0 = height / self.max_grid_row_position_embeddings
|
||||||
|
# if width dimension is to be interpolated
|
||||||
|
if width > self.max_grid_col_position_embeddings:
|
||||||
|
w0 = width / self.max_grid_col_position_embeddings
|
||||||
|
embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width)
|
||||||
|
embedding = nn.functional.interpolate(
|
||||||
|
embedding,
|
||||||
|
scale_factor=(h0, w0),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
grid: (batch_size, height, width, hidden_dim)
|
grid: (batch_size, height, width, hidden_dim)
|
||||||
|
interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
Returns:
|
Returns:
|
||||||
grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
|
grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
|
||||||
"""
|
"""
|
||||||
batch_size, height, width, hidden_dim = grid.shape
|
batch_size, height, width, hidden_dim = grid.shape
|
||||||
|
|
||||||
# add row-wise position embeddings
|
# add row-wise position embeddings
|
||||||
row_position_ids = torch.arange(height, dtype=torch.long, device=grid.device) # (height, )
|
# (height, )
|
||||||
row_position_embeddings = self.row_position_embeddings(row_position_ids) # (height, hidden_dim)
|
row_height = min(self.max_grid_row_position_embeddings, height)
|
||||||
row_shape = (1,) * (len(grid.shape) - 3) + (height, 1, hidden_dim) # (1, height, 1, hidden_dim)
|
row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
|
||||||
grid = grid + row_position_embeddings.view(*row_shape) # broadcast automatically
|
# (height, hidden_dim)
|
||||||
|
row_position_embeddings = self.row_position_embeddings(row_position_ids)
|
||||||
|
row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
|
||||||
|
# (batch_size, height, 1, hidden_dim)
|
||||||
|
row_position_embeddings = row_position_embeddings.view(*row_shape)
|
||||||
|
|
||||||
# add column-wise position embeddings
|
# add column-wise position embeddings
|
||||||
col_position_ids = torch.arange(width, dtype=torch.long, device=grid.device) # (width, )
|
row_width = min(self.max_grid_col_position_embeddings, width)
|
||||||
col_position_embeddings = self.col_position_embeddings(col_position_ids) # (width, hidden_dim)
|
col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
|
||||||
col_shape = (batch_size, 1, width, hidden_dim) # (1, 1, width, hidden_dim)
|
# (width, hidden_dim)
|
||||||
return grid + col_position_embeddings.view(*col_shape) # broadcast automatically
|
col_position_embeddings = self.col_position_embeddings(col_position_ids)
|
||||||
|
col_shape = (batch_size, 1, row_width, hidden_dim)
|
||||||
|
# (batch_size, 1, width, hidden_dim)
|
||||||
|
col_position_embeddings = col_position_embeddings.view(*col_shape)
|
||||||
|
# (batch_size, height, width, hidden_dim)
|
||||||
|
positional_embeddings = row_position_embeddings + col_position_embeddings
|
||||||
|
|
||||||
def forward(self, grid):
|
# This interpolation gets triggered ONLY when the input image dim is larger in any dimenstion than the original position embeddings
|
||||||
|
if interpolate_pos_encoding and (
|
||||||
|
height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
|
||||||
|
):
|
||||||
|
grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
|
||||||
|
else:
|
||||||
|
grid = grid + positional_embeddings
|
||||||
|
return grid
|
||||||
|
|
||||||
|
def forward(self, grid, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
grid: Array of shape (batch_size, num_frames, height, width, num_channels).
|
grid: Array of shape (batch_size, num_frames, height, width, num_channels).
|
||||||
It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
|
It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
|
||||||
num_frames can be 1
|
num_frames can be 1
|
||||||
|
interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
|
embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
|
||||||
@@ -229,7 +276,7 @@ class TvpVisualInputEmbedding(nn.Module):
|
|||||||
batch_size, num_frames, height, width, num_channels = grid.shape
|
batch_size, num_frames, height, width, num_channels = grid.shape
|
||||||
# temporal mean pooling, (batch_size, height, width, hidden_size)
|
# temporal mean pooling, (batch_size, height, width, hidden_size)
|
||||||
grid = grid.mean(1)
|
grid = grid.mean(1)
|
||||||
grid = self.add_2d_positional_embeddings(grid)
|
grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
# image token sequence, (batch_size, height*width, num_channels)
|
# image token sequence, (batch_size, height*width, num_channels)
|
||||||
visual_tokens = grid.view(batch_size, -1, num_channels)
|
visual_tokens = grid.view(batch_size, -1, num_channels)
|
||||||
visual_tokens_shape = visual_tokens.shape[:-1]
|
visual_tokens_shape = visual_tokens.shape[:-1]
|
||||||
@@ -586,6 +633,9 @@ TVP_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained image pad prompter encodings and positional encodings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -639,7 +689,6 @@ class TvpFramePadPrompter(nn.Module):
|
|||||||
self.num_frames = config.num_frames
|
self.num_frames = config.num_frames
|
||||||
self.max_img_size = config.max_img_size
|
self.max_img_size = config.max_img_size
|
||||||
self.visual_prompter_apply = config.visual_prompter_apply
|
self.visual_prompter_apply = config.visual_prompter_apply
|
||||||
|
|
||||||
self.base_size = config.max_img_size - config.visual_prompt_size * 2
|
self.base_size = config.max_img_size - config.visual_prompt_size * 2
|
||||||
self.pad_up = nn.Parameter(
|
self.pad_up = nn.Parameter(
|
||||||
torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
|
torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
|
||||||
@@ -670,19 +719,49 @@ class TvpFramePadPrompter(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
|
||||||
|
resolution images (high resolution videos).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# creates scale factor from height and width of original image wrt to the config.max_img_size
|
||||||
|
h0, w0 = height / self.max_img_size, width / self.max_img_size
|
||||||
|
|
||||||
|
batch, num_frames, channels, prompt_height, prompt_width = prompt.shape
|
||||||
|
|
||||||
|
# reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation
|
||||||
|
prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
|
||||||
|
prompt = nn.functional.interpolate(
|
||||||
|
prompt,
|
||||||
|
scale_factor=(h0, w0),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
# reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width
|
||||||
|
prompt = prompt.reshape(batch, num_frames, channels, height, width)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
|
||||||
|
height, width = (
|
||||||
|
(pixel_values.shape[-2], pixel_values.shape[-1])
|
||||||
|
if interpolate_pad_encoding
|
||||||
|
else (self.max_img_size, self.max_img_size)
|
||||||
|
)
|
||||||
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
||||||
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
||||||
if self.visual_prompter_apply in ("replace", "remove"):
|
if self.visual_prompter_apply in ("replace", "remove"):
|
||||||
visual_prompt_mask = torch.ones(
|
visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
|
||||||
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
|
||||||
)
|
|
||||||
pixel_values *= visual_prompt_mask
|
pixel_values *= visual_prompt_mask
|
||||||
if self.visual_prompter_apply in ("replace", "add"):
|
if self.visual_prompter_apply in ("replace", "add"):
|
||||||
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
|
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
|
||||||
|
|
||||||
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
||||||
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
||||||
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
||||||
|
if interpolate_pad_encoding:
|
||||||
|
prompt = self.interpolate_pad_encoding(prompt, height, width)
|
||||||
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
|
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
@@ -738,6 +817,7 @@ class TvpModel(TvpPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -756,13 +836,17 @@ class TvpModel(TvpPreTrainedModel):
|
|||||||
>>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
|
>>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
# Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
|
# Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
|
||||||
pixel_values = self.vision_model(self.visual_prompter(pixel_values))
|
pixel_values = self.vision_model(
|
||||||
|
self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
|
||||||
|
)
|
||||||
# (batch_size, sequence_length, hidden_size)
|
# (batch_size, sequence_length, hidden_size)
|
||||||
text_embedding_output = self.embeddings(input_ids=input_ids)
|
text_embedding_output = self.embeddings(input_ids=input_ids)
|
||||||
# (batch_size, visual_sequence_length, hidden_size)
|
# (batch_size, visual_sequence_length, hidden_size)
|
||||||
visual_embedding_output = self.visual_embeddings(pixel_values)
|
visual_embedding_output = self.visual_embeddings(
|
||||||
|
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# (batch_size, visual_sequence_length)
|
# (batch_size, visual_sequence_length)
|
||||||
visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
|
visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
|
||||||
@@ -791,7 +875,6 @@ class TvpModel(TvpPreTrainedModel):
|
|||||||
pooled_output = self.dropout(pooled_output)
|
pooled_output = self.dropout(pooled_output)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
return BaseModelOutputWithPooling(
|
||||||
last_hidden_state=last_hidden_state,
|
last_hidden_state=last_hidden_state,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
@@ -841,6 +924,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
|
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
|
||||||
@@ -869,9 +953,9 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
pooler_output = outputs[1]
|
pooler_output = outputs[1]
|
||||||
|
|
||||||
logits = self.video_grounding_head(pooler_output)
|
logits = self.video_grounding_head(pooler_output)
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
@@ -884,7 +968,6 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
|||||||
+ self.config.distance_loss_weight * loss_dict["distance"]
|
+ self.config.distance_loss_weight * loss_dict["distance"]
|
||||||
+ self.config.duration_loss_weight * loss_dict["duration"]
|
+ self.config.duration_loss_weight * loss_dict["duration"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
outputs = (logits,) + outputs[2:]
|
outputs = (logits,) + outputs[2:]
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
|
|||||||
@@ -256,7 +256,7 @@ def prepare_img():
|
|||||||
class TvpModelIntegrationTests(unittest.TestCase):
|
class TvpModelIntegrationTests(unittest.TestCase):
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_image_processor(self):
|
def default_image_processor(self):
|
||||||
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None
|
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")
|
||||||
|
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||||
@@ -297,3 +297,41 @@ class TvpModelIntegrationTests(unittest.TestCase):
|
|||||||
assert outputs.logits.shape == expected_shape
|
assert outputs.logits.shape == expected_shape
|
||||||
expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
|
expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
def test_interpolate_inference_no_head(self):
|
||||||
|
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img() # 480X640
|
||||||
|
encoding = image_processor(
|
||||||
|
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor([[1, 2]])
|
||||||
|
attention_mask = torch.tensor([[1, 1]])
|
||||||
|
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||||
|
encoding.to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**encoding, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 1212, 128))
|
||||||
|
assert outputs.last_hidden_state.shape == expected_shape
|
||||||
|
|
||||||
|
def test_interpolate_inference_with_head(self):
|
||||||
|
model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img() # 480X640
|
||||||
|
encoding = image_processor(
|
||||||
|
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
|
||||||
|
)
|
||||||
|
input_ids = torch.tensor([[1, 2]])
|
||||||
|
attention_mask = torch.tensor([[1, 1]])
|
||||||
|
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||||
|
encoding.to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True)
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 1212, 128))
|
||||||
|
assert outputs.hidden_states[-1].shape == expected_shape
|
||||||
|
|||||||
Reference in New Issue
Block a user