Fix TVPModelTest (#27695)
* fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -606,12 +606,15 @@ class TvpFrameDownPadPrompter(nn.Module):
|
||||
|
||||
def forward(self, pixel_values):
|
||||
if self.visual_prompter_apply != "add":
|
||||
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
|
||||
visual_prompt_mask = torch.ones(
|
||||
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
||||
)
|
||||
visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
|
||||
pixel_values *= visual_prompt_mask
|
||||
if self.visual_prompter_apply != "remove":
|
||||
prompt = torch.zeros(
|
||||
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size]
|
||||
[pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
|
||||
device=pixel_values.device,
|
||||
)
|
||||
start_point = self.max_img_size - self.visual_prompt_size
|
||||
prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
|
||||
@@ -667,10 +670,12 @@ class TvpFramePadPrompter(nn.Module):
|
||||
if self.visual_prompter_apply not in ("add", "remove", "replace"):
|
||||
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
|
||||
if self.visual_prompter_apply in ("replace", "remove"):
|
||||
visual_prompt_mask = torch.ones([self.max_img_size, self.max_img_size], dtype=pixel_values.dtype)
|
||||
visual_prompt_mask = torch.ones(
|
||||
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
|
||||
)
|
||||
pixel_values *= visual_prompt_mask
|
||||
if self.visual_prompter_apply in ("replace", "add"):
|
||||
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size)
|
||||
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
|
||||
prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
|
||||
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
|
||||
prompt = torch.cat(pixel_values.size(0) * [prompt])
|
||||
|
||||
@@ -176,6 +176,9 @@ class TVPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else {}
|
||||
)
|
||||
|
||||
# TODO: Enable this once this model gets more usage
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TVPModelTester(self)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user