[MMS] Fix mms (#25267)
* [MMS] Fix mms * [MMS] Fix mms * fix mms loading * Apply suggestions from code review * make style * Update tests/models/wav2vec2/test_modeling_wav2vec2.py
This commit is contained in:
committed by
GitHub
parent
ad8321512d
commit
b28ebb2655
@@ -3047,6 +3047,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
offload_state_dict = True
|
offload_state_dict = True
|
||||||
|
|
||||||
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
|
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
|
||||||
|
|
||||||
|
# tie the model weights before retrieving the state_dict
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
@@ -3092,7 +3096,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model_buffers = {".".join([prefix, key]) for key in model_buffers}
|
model_buffers = {".".join([prefix, key]) for key in model_buffers}
|
||||||
unexpected_keys = list(unexpected_keys - model_buffers)
|
unexpected_keys = list(unexpected_keys - model_buffers)
|
||||||
|
|
||||||
model.tie_weights()
|
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
ptrs = collections.defaultdict(list)
|
ptrs = collections.defaultdict(list)
|
||||||
for name, tensor in model.state_dict().items():
|
for name, tensor in model.state_dict().items():
|
||||||
|
|||||||
@@ -1151,6 +1151,43 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||||
|
|
||||||
|
# test that loading adapter weights with mismatched vocab sizes can be loaded
|
||||||
|
def test_load_target_lang_with_mismatched_size(self):
|
||||||
|
processor = Wav2Vec2Processor.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_logits(model, input_features):
|
||||||
|
model = model.to(torch_device)
|
||||||
|
batch = processor(
|
||||||
|
input_features,
|
||||||
|
padding=True,
|
||||||
|
sampling_rate=processor.feature_extractor.sampling_rate,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(
|
||||||
|
input_values=batch["input_values"].to(torch_device),
|
||||||
|
attention_mask=batch["attention_mask"].to(torch_device),
|
||||||
|
).logits
|
||||||
|
return logits
|
||||||
|
|
||||||
|
input_features = [np.random.random(16_000 * s) for s in [1, 3, 2, 6]]
|
||||||
|
|
||||||
|
model = Wav2Vec2ForCTC.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-wav2vec2-adapter", target_lang="fr", ignore_mismatched_sizes=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = get_logits(model, input_features)
|
||||||
|
|
||||||
|
model_2 = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
|
||||||
|
model_2.load_adapter("fr")
|
||||||
|
|
||||||
|
logits_2 = get_logits(model_2, input_features)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||||
|
|
||||||
def test_load_attn_adapter(self):
|
def test_load_attn_adapter(self):
|
||||||
processor = Wav2Vec2Processor.from_pretrained(
|
processor = Wav2Vec2Processor.from_pretrained(
|
||||||
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
"hf-internal-testing/tiny-random-wav2vec2", return_attention_mask=True
|
||||||
|
|||||||
Reference in New Issue
Block a user