Purge unused ModelTester code (#37085)
* Purge correctly this time * Remove more methods from recent PRs * make fixup
This commit is contained in:
@@ -246,32 +246,6 @@ class Wav2Vec2BertModelTester:
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_features, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
model = Wav2Vec2BertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
input_features = input_features[:3]
|
||||
attention_mask = torch.ones(input_features.shape, device=torch_device, dtype=torch.bool)
|
||||
|
||||
input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_features[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0.0
|
||||
|
||||
batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
|
||||
|
||||
for i in range(input_features.shape[0]):
|
||||
input_slice = input_features[i : i + 1, : input_lengths[i]]
|
||||
output = model(input_slice).last_hidden_state
|
||||
|
||||
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
|
||||
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
|
||||
|
||||
def check_ctc_loss(self, config, input_features, *args):
|
||||
model = Wav2Vec2BertForCTC(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user