From 03aaac35021dead4fb7ad354fe9c986d16869f03 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 21 Feb 2023 11:37:49 +0100 Subject: [PATCH] Fix TVLT (torch device issue) (#21710) * fix tvlt ci * fix tvlt ci * fix tvlt ci --------- Co-authored-by: ydshieh --- src/transformers/models/tvlt/modeling_tvlt.py | 9 ++++++--- tests/models/tvlt/test_modeling_tvlt.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 10250a1d6c..3725c5e772 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -153,7 +153,7 @@ def generate_pixel_mask_noise(pixel_values, pixel_mask=None, mask_ratio=0.75): """Generate noise for audio masking.""" batch_size, seq_len = pixel_values.shape[:2] - noise = torch.rand((batch_size, seq_len)) # noise in [0, 1] + noise = torch.rand((batch_size, seq_len), device=pixel_values.device) # noise in [0, 1] len_keep = int(seq_len * (1 - mask_ratio)) return noise, len_keep @@ -165,10 +165,13 @@ def generate_audio_mask_noise(audio_values, audio_mask=None, mask_ratio=0.75, ma if mask_type == "frame-level": num_time_patches = seq_len // freq_len noise = ( - torch.rand(batch_size, num_time_patches).unsqueeze(-1).repeat(1, 1, freq_len).view(batch_size, seq_len) + torch.rand(batch_size, num_time_patches, device=audio_values.device) + .unsqueeze(-1) + .repeat(1, 1, freq_len) + .view(batch_size, seq_len) ) # noise in [0, 1] elif mask_type == "patch-level": - noise = torch.rand(batch_size, seq_len) # noise in [0, 1] + noise = torch.rand(batch_size, seq_len, device=audio_values.device) # noise in [0, 1] len_keep = int(seq_len * (1 - mask_ratio)) return noise, len_keep diff --git a/tests/models/tvlt/test_modeling_tvlt.py b/tests/models/tvlt/test_modeling_tvlt.py index 4d5877c438..0f3d5ab68a 100644 --- a/tests/models/tvlt/test_modeling_tvlt.py +++ b/tests/models/tvlt/test_modeling_tvlt.py @@ -590,7 +590,7 @@ class TvltModelIntegrationTest(unittest.TestCase): outputs = model(**inputs) # verify the logits - expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]]) + expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]], device=torch_device) self.assertTrue( torch.allclose(outputs.last_hidden_state[:, :2, :2], expected_last_hidden_state_slice, atol=1e-4) )