Fix TVPModelTest (#27695)
* fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -606,12 +606,15 @@ class TvpFrameDownPadPrompter(nn.Module):
|
|||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
if self.visual_prompter_apply != "add":
|
if self.visual_prompter_apply != "add":
|
||||||
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
|
visual_prompt_mask = torch.ones(
|
||||||
|
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
||||||
|
)
|
||||||
visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
|
visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
|
||||||
pixel_values *= visual_prompt_mask
|
pixel_values *= visual_prompt_mask
|
||||||
if self.visual_prompter_apply != "remove":
|
if self.visual_prompter_apply != "remove":
|
||||||
prompt = torch.zeros(
|
prompt = torch.zeros(
|
||||||
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size]
|
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
|
||||||
|
device=pixel_values.device,
|
||||||
)
|
)
|
||||||
start_point = self.max_img_size - self.visual_prompt_size
|
start_point = self.max_img_size - self.visual_prompt_size
|
||||||
prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
|
prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
|
||||||
@@ -667,10 +670,12 @@ class TvpFramePadPrompter(nn.Module):
|
|||||||
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
||||||
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
||||||
if self.visual_prompter_apply in ("replace", "remove"):
|
if self.visual_prompter_apply in ("replace", "remove"):
|
||||||
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
|
visual_prompt_mask = torch.ones(
|
||||||
|
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
||||||
|
)
|
||||||
pixel_values *= visual_prompt_mask
|
pixel_values *= visual_prompt_mask
|
||||||
if self.visual_prompter_apply in ("replace", "add"):
|
if self.visual_prompter_apply in ("replace", "add"):
|
||||||
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size)
|
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
|
||||||
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
||||||
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
||||||
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
||||||
|
|||||||
@@ -176,6 +176,9 @@ class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Enable this once this model gets more usage
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TVPModelTester(self)
|
self.model_tester = TVPModelTester(self)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user