Fix TVLT (torch device issue) (#21710)
* fix tvlt ci * fix tvlt ci * fix tvlt ci --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -153,7 +153,7 @@ def generate_pixel_mask_noise(pixel_values, pixel_mask=None, mask_ratio=0.75):
|
|||||||
"""Generate noise for audio masking."""
|
"""Generate noise for audio masking."""
|
||||||
|
|
||||||
batch_size, seq_len = pixel_values.shape[:2]
|
batch_size, seq_len = pixel_values.shape[:2]
|
||||||
noise = torch.rand((batch_size, seq_len)) # noise in [0, 1]
|
noise = torch.rand((batch_size, seq_len), device=pixel_values.device) # noise in [0, 1]
|
||||||
len_keep = int(seq_len * (1 - mask_ratio))
|
len_keep = int(seq_len * (1 - mask_ratio))
|
||||||
return noise, len_keep
|
return noise, len_keep
|
||||||
|
|
||||||
@@ -165,10 +165,13 @@ def generate_audio_mask_noise(audio_values, audio_mask=None, mask_ratio=0.75, ma
|
|||||||
if mask_type == "frame-level":
|
if mask_type == "frame-level":
|
||||||
num_time_patches = seq_len // freq_len
|
num_time_patches = seq_len // freq_len
|
||||||
noise = (
|
noise = (
|
||||||
torch.rand(batch_size, num_time_patches).unsqueeze(-1).repeat(1, 1, freq_len).view(batch_size, seq_len)
|
torch.rand(batch_size, num_time_patches, device=audio_values.device)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.repeat(1, 1, freq_len)
|
||||||
|
.view(batch_size, seq_len)
|
||||||
) # noise in [0, 1]
|
) # noise in [0, 1]
|
||||||
elif mask_type == "patch-level":
|
elif mask_type == "patch-level":
|
||||||
noise = torch.rand(batch_size, seq_len) # noise in [0, 1]
|
noise = torch.rand(batch_size, seq_len, device=audio_values.device) # noise in [0, 1]
|
||||||
len_keep = int(seq_len * (1 - mask_ratio))
|
len_keep = int(seq_len * (1 - mask_ratio))
|
||||||
return noise, len_keep
|
return noise, len_keep
|
||||||
|
|
||||||
|
|||||||
@@ -590,7 +590,7 @@ class TvltModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]])
|
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]], device=torch_device)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(outputs.last_hidden_state[:, :2, :2], expected_last_hidden_state_slice, atol=1e-4)
|
torch.allclose(outputs.last_hidden_state[:, :2, :2], expected_last_hidden_state_slice, atol=1e-4)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user