interpolation added for TVP. (#30863)
* Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * docstring and argument fix * doc fixes and test case fix suggested in review. * varibale typo fix * styling and name fixes for padding interpolation flag.
This commit is contained in:
@@ -193,34 +193,81 @@ class TvpVisualInputEmbedding(nn.Module):
|
||||
self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
|
||||
self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings
|
||||
|
||||
def add_2d_positional_embeddings(self, grid):
|
||||
def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high
|
||||
resolution images (high resolution videos).
|
||||
|
||||
"""
|
||||
h0 = w0 = 1
|
||||
# if height dimension is to be interpolated
|
||||
if height > self.max_grid_row_position_embeddings:
|
||||
h0 = height / self.max_grid_row_position_embeddings
|
||||
# if width dimension is to be interpolated
|
||||
if width > self.max_grid_col_position_embeddings:
|
||||
w0 = width / self.max_grid_col_position_embeddings
|
||||
embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width)
|
||||
embedding = nn.functional.interpolate(
|
||||
embedding,
|
||||
scale_factor=(h0, w0),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim)
|
||||
return embedding
|
||||
|
||||
def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
grid: (batch_size, height, width, hidden_dim)
|
||||
interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
Returns:
|
||||
grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
|
||||
"""
|
||||
batch_size, height, width, hidden_dim = grid.shape
|
||||
|
||||
# add row-wise position embeddings
|
||||
row_position_ids = torch.arange(height, dtype=torch.long, device=grid.device) # (height, )
|
||||
row_position_embeddings = self.row_position_embeddings(row_position_ids) # (height, hidden_dim)
|
||||
row_shape = (1,) * (len(grid.shape) - 3) + (height, 1, hidden_dim) # (1, height, 1, hidden_dim)
|
||||
grid = grid + row_position_embeddings.view(*row_shape) # broadcast automatically
|
||||
# (height, )
|
||||
row_height = min(self.max_grid_row_position_embeddings, height)
|
||||
row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
|
||||
# (height, hidden_dim)
|
||||
row_position_embeddings = self.row_position_embeddings(row_position_ids)
|
||||
row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
|
||||
# (batch_size, height, 1, hidden_dim)
|
||||
row_position_embeddings = row_position_embeddings.view(*row_shape)
|
||||
|
||||
# add column-wise position embeddings
|
||||
col_position_ids = torch.arange(width, dtype=torch.long, device=grid.device) # (width, )
|
||||
col_position_embeddings = self.col_position_embeddings(col_position_ids) # (width, hidden_dim)
|
||||
col_shape = (batch_size, 1, width, hidden_dim) # (1, 1, width, hidden_dim)
|
||||
return grid + col_position_embeddings.view(*col_shape) # broadcast automatically
|
||||
row_width = min(self.max_grid_col_position_embeddings, width)
|
||||
col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
|
||||
# (width, hidden_dim)
|
||||
col_position_embeddings = self.col_position_embeddings(col_position_ids)
|
||||
col_shape = (batch_size, 1, row_width, hidden_dim)
|
||||
# (batch_size, 1, width, hidden_dim)
|
||||
col_position_embeddings = col_position_embeddings.view(*col_shape)
|
||||
# (batch_size, height, width, hidden_dim)
|
||||
positional_embeddings = row_position_embeddings + col_position_embeddings
|
||||
|
||||
def forward(self, grid):
|
||||
# This interpolation gets triggered ONLY when the input image dim is larger in any dimenstion than the original position embeddings
|
||||
if interpolate_pos_encoding and (
|
||||
height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
|
||||
):
|
||||
grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
|
||||
else:
|
||||
grid = grid + positional_embeddings
|
||||
return grid
|
||||
|
||||
def forward(self, grid, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
grid: Array of shape (batch_size, num_frames, height, width, num_channels).
|
||||
It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
|
||||
num_frames can be 1
|
||||
interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
|
||||
Returns:
|
||||
embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
|
||||
@@ -229,7 +276,7 @@ class TvpVisualInputEmbedding(nn.Module):
|
||||
batch_size, num_frames, height, width, num_channels = grid.shape
|
||||
# temporal mean pooling, (batch_size, height, width, hidden_size)
|
||||
grid = grid.mean(1)
|
||||
grid = self.add_2d_positional_embeddings(grid)
|
||||
grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
# image token sequence, (batch_size, height*width, num_channels)
|
||||
visual_tokens = grid.view(batch_size, -1, num_channels)
|
||||
visual_tokens_shape = visual_tokens.shape[:-1]
|
||||
@@ -586,6 +633,9 @@ TVP_INPUTS_DOCSTRING = r"""
|
||||
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained image pad prompter encodings and positional encodings.
|
||||
"""
|
||||
|
||||
|
||||
@@ -639,7 +689,6 @@ class TvpFramePadPrompter(nn.Module):
|
||||
self.num_frames = config.num_frames
|
||||
self.max_img_size = config.max_img_size
|
||||
self.visual_prompter_apply = config.visual_prompter_apply
|
||||
|
||||
self.base_size = config.max_img_size - config.visual_prompt_size * 2
|
||||
self.pad_up = nn.Parameter(
|
||||
torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
|
||||
@@ -670,19 +719,49 @@ class TvpFramePadPrompter(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
|
||||
resolution images (high resolution videos).
|
||||
|
||||
"""
|
||||
|
||||
# creates scale factor from height and width of original image wrt to the config.max_img_size
|
||||
h0, w0 = height / self.max_img_size, width / self.max_img_size
|
||||
|
||||
batch, num_frames, channels, prompt_height, prompt_width = prompt.shape
|
||||
|
||||
# reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation
|
||||
prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
|
||||
prompt = nn.functional.interpolate(
|
||||
prompt,
|
||||
scale_factor=(h0, w0),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
# reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width
|
||||
prompt = prompt.reshape(batch, num_frames, channels, height, width)
|
||||
return prompt
|
||||
|
||||
def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
|
||||
height, width = (
|
||||
(pixel_values.shape[-2], pixel_values.shape[-1])
|
||||
if interpolate_pad_encoding
|
||||
else (self.max_img_size, self.max_img_size)
|
||||
)
|
||||
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
||||
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
||||
if self.visual_prompter_apply in ("replace", "remove"):
|
||||
visual_prompt_mask = torch.ones(
|
||||
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
||||
)
|
||||
visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
pixel_values *= visual_prompt_mask
|
||||
if self.visual_prompter_apply in ("replace", "add"):
|
||||
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
|
||||
|
||||
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
||||
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
||||
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
||||
if interpolate_pad_encoding:
|
||||
prompt = self.interpolate_pad_encoding(prompt, height, width)
|
||||
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
|
||||
return pixel_values
|
||||
|
||||
@@ -738,6 +817,7 @@ class TvpModel(TvpPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@@ -756,13 +836,17 @@ class TvpModel(TvpPreTrainedModel):
|
||||
>>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
# Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
|
||||
pixel_values = self.vision_model(self.visual_prompter(pixel_values))
|
||||
pixel_values = self.vision_model(
|
||||
self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
|
||||
)
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
text_embedding_output = self.embeddings(input_ids=input_ids)
|
||||
# (batch_size, visual_sequence_length, hidden_size)
|
||||
visual_embedding_output = self.visual_embeddings(pixel_values)
|
||||
visual_embedding_output = self.visual_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# (batch_size, visual_sequence_length)
|
||||
visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
|
||||
@@ -791,7 +875,6 @@ class TvpModel(TvpPreTrainedModel):
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@@ -841,6 +924,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
|
||||
@@ -869,9 +953,9 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
pooler_output = outputs[1]
|
||||
|
||||
logits = self.video_grounding_head(pooler_output)
|
||||
|
||||
loss = None
|
||||
@@ -884,7 +968,6 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
|
||||
+ self.config.distance_loss_weight * loss_dict["distance"]
|
||||
+ self.config.duration_loss_weight * loss_dict["duration"]
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (logits,) + outputs[2:]
|
||||
if loss is not None:
|
||||
|
||||
@@ -256,7 +256,7 @@ def prepare_img():
|
||||
class TvpModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None
|
||||
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")
|
||||
|
||||
def test_inference_no_head(self):
|
||||
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||
@@ -297,3 +297,41 @@ class TvpModelIntegrationTests(unittest.TestCase):
|
||||
assert outputs.logits.shape == expected_shape
|
||||
expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))
|
||||
|
||||
def test_interpolate_inference_no_head(self):
|
||||
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img() # 480X640
|
||||
encoding = image_processor(
|
||||
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
|
||||
)
|
||||
input_ids = torch.tensor([[1, 2]])
|
||||
attention_mask = torch.tensor([[1, 1]])
|
||||
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||
encoding.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding, interpolate_pos_encoding=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1212, 128))
|
||||
assert outputs.last_hidden_state.shape == expected_shape
|
||||
|
||||
def test_interpolate_inference_with_head(self):
|
||||
model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img() # 480X640
|
||||
encoding = image_processor(
|
||||
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
|
||||
)
|
||||
input_ids = torch.tensor([[1, 2]])
|
||||
attention_mask = torch.tensor([[1, 1]])
|
||||
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||
encoding.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1212, 128))
|
||||
assert outputs.hidden_states[-1].shape == expected_shape
|
||||
|
||||
Reference in New Issue
Block a user