Fix WhisperModelTest (#21883)

* force on the same device

* fix tests

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-01 20:41:27 +01:00
committed by GitHub
parent 4edfd2d4d2
commit 36ee128375

View File

@@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
# Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222)
model_split_percents = [0.8, 0.9]
input_name = "input_features" input_name = "input_features"
@@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
input_features = inputs["input_features"] input_features = inputs["input_features"]
decoder_input_ids = inputs["decoder_input_ids"] decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"] decoder_attention_mask = inputs["decoder_attention_mask"]
traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask)) # prepare `attention_mask` with shape (batch_size, sequence_length)
attention_mask = torch.ones(
input_features.shape[0],
input_features.shape[-1],
device=input_features.device,
dtype=input_features.dtype,
)
traced_model = torch.jit.trace(
model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
)
except RuntimeError: except RuntimeError:
self.fail("Couldn't trace module.") self.fail("Couldn't trace module.")