From fe085560d05b3a4a00464f9dd693dda34dc93d63 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:12:20 +0100 Subject: [PATCH] Fix `multi_gpu_data_parallel_forward` for `MusicgenTest` (#29632) update Co-authored-by: ydshieh --- tests/models/musicgen/test_modeling_musicgen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index cd978d8987..e2e7da36ea 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -103,7 +103,7 @@ class MusicgenDecoderTester: def __init__( self, parent, - batch_size=3, # need batch_size != num_hidden_layers + batch_size=4, # need batch_size != num_hidden_layers seq_length=7, is_training=False, use_labels=False, @@ -441,7 +441,7 @@ class MusicgenTester: def __init__( self, parent, - batch_size=3, # need batch_size != num_hidden_layers + batch_size=4, # need batch_size != num_hidden_layers seq_length=7, is_training=False, use_labels=False,