Fix multi_gpu_data_parallel_forward for MusicgenTest (#29632)
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -103,7 +103,7 @@ class MusicgenDecoderTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=3, # need batch_size != num_hidden_layers
|
batch_size=4, # need batch_size != num_hidden_layers
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
use_labels=False,
|
use_labels=False,
|
||||||
@@ -441,7 +441,7 @@ class MusicgenTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=3, # need batch_size != num_hidden_layers
|
batch_size=4, # need batch_size != num_hidden_layers
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
use_labels=False,
|
use_labels=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user