Fix TVPModelTest (#27695)

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-11-24 19:47:50 +01:00
committed by GitHub
parent 29c94808ea
commit 35551f9a0f
2 changed files with 12 additions and 4 deletions

View File

@@ -606,12 +606,15 @@ class TvpFrameDownPadPrompter(nn.Module):
def forward(self, pixel_values):
if self.visual_prompter_apply != "add":
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply != "remove":
prompt = torch.zeros(
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size]
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
device=pixel_values.device,
)
start_point = self.max_img_size - self.visual_prompt_size
prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
@@ -667,10 +670,12 @@ class TvpFramePadPrompter(nn.Module):
if self.visual_prompter_apply not in ("add", "remove", "replace"):
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
if self.visual_prompter_apply in ("replace", "remove"):
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply in ("replace", "add"):
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size)
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
prompt = torch.cat(pixel_values.size(0) * [prompt])

View File

@@ -176,6 +176,9 @@ class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else {}
)
# TODO: Enable this once this model gets more usage
test_torchscript = False
def setUp(self):
self.model_tester = TVPModelTester(self)