Fix WhisperModelTest (#21883)
* force on the same device * fix tests --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
# Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222)
|
||||
model_split_percents = [0.8, 0.9]
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
@@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
input_features = inputs["input_features"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask))
|
||||
# prepare `attention_mask` with shape (batch_size, sequence_length)
|
||||
attention_mask = torch.ones(
|
||||
input_features.shape[0],
|
||||
input_features.shape[-1],
|
||||
device=input_features.device,
|
||||
dtype=input_features.dtype,
|
||||
)
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user