Fixing slow pipeline tests (#14260)
* Fiixng slow pipeline tests * Remove the image-segmentaiton override. * Fixing clamping only in training. * Wav2vec2. * Remove last mention of `no_grad`. * Fixing copies. * Rename.
This commit is contained in:
@@ -648,6 +648,7 @@ class DetrEncoderLayer(nn.Module):
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if self.training:
|
||||
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
@@ -947,7 +947,10 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
||||
# on inference mode.
|
||||
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
|
||||
@@ -948,7 +948,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
||||
# on inference mode.
|
||||
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
|
||||
@@ -989,7 +989,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
||||
# on inference mode.
|
||||
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
|
||||
@@ -91,9 +91,6 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
def get_inference_context(self):
|
||||
return torch.no_grad
|
||||
|
||||
def preprocess(self, image):
|
||||
image = load_image(image)
|
||||
target_size = torch.IntTensor([[image.height, image.width]])
|
||||
|
||||
@@ -93,7 +93,6 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
)
|
||||
|
||||
def batch_inference(self, **inputs):
|
||||
with torch.no_grad():
|
||||
return self.model(**inputs)
|
||||
|
||||
def sequential_inference(self, **inputs):
|
||||
@@ -101,7 +100,6 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
|
||||
handle conversational query related to a table.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
all_logits = []
|
||||
all_aggregations = []
|
||||
prev_answers = None
|
||||
|
||||
@@ -117,7 +117,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
self.assertEqual(
|
||||
nested_simplify(output, decimals=4),
|
||||
[
|
||||
{"score": 0.9809, "label": "go"},
|
||||
{"score": 0.981, "label": "go"},
|
||||
{"score": 0.0073, "label": "up"},
|
||||
{"score": 0.0064, "label": "_unknown_"},
|
||||
{"score": 0.0015, "label": "down"},
|
||||
|
||||
Reference in New Issue
Block a user