[Modeling] Fix encoder CPU offloading for whisper (#38994)
* fix cpu offloading for whisper Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * unskip offloading tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * revert small change Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -687,9 +687,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
||||
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
||||
|
||||
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
||||
embed_pos = self.embed_positions.weight
|
||||
all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = inputs_embeds + self.embed_positions(all_positions)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
|
||||
@@ -3356,22 +3356,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Not applicable for an encoder-only acoustic model")
|
||||
def test_inputs_embeds(self):
|
||||
# input embeds is meaningless for an encoder-only acoustic model
|
||||
|
||||
Reference in New Issue
Block a user