Fix WhisperModelTest (#21883)
* force on the same device * fix tests --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
# Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222)
|
||||||
|
model_split_percents = [0.8, 0.9]
|
||||||
|
|
||||||
input_name = "input_features"
|
input_name = "input_features"
|
||||||
|
|
||||||
@@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
input_features = inputs["input_features"]
|
input_features = inputs["input_features"]
|
||||||
decoder_input_ids = inputs["decoder_input_ids"]
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask))
|
# prepare `attention_mask` with shape (batch_size, sequence_length)
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
input_features.shape[0],
|
||||||
|
input_features.shape[-1],
|
||||||
|
device=input_features.device,
|
||||||
|
dtype=input_features.dtype,
|
||||||
|
)
|
||||||
|
traced_model = torch.jit.trace(
|
||||||
|
model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
)
|
||||||
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
self.fail("Couldn't trace module.")
|
self.fail("Couldn't trace module.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user