From 68427c9bebd1e4ff43d25b18bb9c7eb786303712 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Nov 2021 09:49:55 +0100 Subject: [PATCH] Fixing slow pipeline tests (#14260) * Fiixng slow pipeline tests * Remove the image-segmentaiton override. * Fixing clamping only in training. * Wav2vec2. * Remove last mention of `no_grad`. * Fixing copies. * Rename. --- src/transformers/models/detr/modeling_detr.py | 7 +- .../models/unispeech/modeling_unispeech.py | 5 +- .../unispeech_sat/modeling_unispeech_sat.py | 5 +- .../models/wav2vec2/modeling_wav2vec2.py | 5 +- .../pipelines/image_segmentation.py | 3 - .../pipelines/table_question_answering.py | 110 +++++++++--------- tests/test_pipelines_audio_classification.py | 2 +- 7 files changed, 71 insertions(+), 66 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index af650e75e1..5e95cc3f32 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -648,9 +648,10 @@ class DetrEncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index b4a3423516..a8a89c302b 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -947,7 +947,10 @@ class UniSpeechPreTrainedModel(PreTrainedModel): return input_lengths def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) batch_size = attention_mask.shape[0] attention_mask = torch.zeros( diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index bd27b53edb..c5f8243bf1 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -948,7 +948,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): return input_lengths def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) batch_size = attention_mask.shape[0] attention_mask = torch.zeros( diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 485983acd5..6548f245f0 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -989,7 +989,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): return input_lengths def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) batch_size = attention_mask.shape[0] attention_mask = torch.zeros( diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index 84a3e67ef6..fac8cddc67 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -91,9 +91,6 @@ class ImageSegmentationPipeline(Pipeline): return super().__call__(*args, **kwargs) - def get_inference_context(self): - return torch.no_grad - def preprocess(self, image): image = load_image(image) target_size = torch.IntTensor([[image.height, image.width]]) diff --git a/src/transformers/pipelines/table_question_answering.py b/src/transformers/pipelines/table_question_answering.py index a58e3aacbe..7697752b2b 100644 --- a/src/transformers/pipelines/table_question_answering.py +++ b/src/transformers/pipelines/table_question_answering.py @@ -93,76 +93,74 @@ class TableQuestionAnsweringPipeline(Pipeline): ) def batch_inference(self, **inputs): - with torch.no_grad(): - return self.model(**inputs) + return self.model(**inputs) def sequential_inference(self, **inputs): """ Inference used for models that need to process sequences in a sequential fashion, like the SQA models which handle conversational query related to a table. """ - with torch.no_grad(): - all_logits = [] - all_aggregations = [] - prev_answers = None - batch_size = inputs["input_ids"].shape[0] + all_logits = [] + all_aggregations = [] + prev_answers = None + batch_size = inputs["input_ids"].shape[0] - input_ids = inputs["input_ids"].to(self.device) - attention_mask = inputs["attention_mask"].to(self.device) - token_type_ids = inputs["token_type_ids"].to(self.device) - token_type_ids_example = None + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs["attention_mask"].to(self.device) + token_type_ids = inputs["token_type_ids"].to(self.device) + token_type_ids_example = None - for index in range(batch_size): - # If sequences have already been processed, the token type IDs will be created according to the previous - # answer. - if prev_answers is not None: - prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,) - model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,) + for index in range(batch_size): + # If sequences have already been processed, the token type IDs will be created according to the previous + # answer. + if prev_answers is not None: + prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,) + model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,) - token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - for i in range(model_labels.shape[0]): - segment_id = token_type_ids_example[:, 0].tolist()[i] - col_id = token_type_ids_example[:, 1].tolist()[i] - 1 - row_id = token_type_ids_example[:, 2].tolist()[i] - 1 - - if row_id >= 0 and col_id >= 0 and segment_id == 1: - model_labels[i] = int(prev_answers[(col_id, row_id)]) - - token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device) - - input_ids_example = input_ids[index] - attention_mask_example = attention_mask[index] # shape (seq_len,) token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) - outputs = self.model( - input_ids=input_ids_example.unsqueeze(0), - attention_mask=attention_mask_example.unsqueeze(0), - token_type_ids=token_type_ids_example.unsqueeze(0), - ) - logits = outputs.logits - - if self.aggregate: - all_aggregations.append(outputs.logits_aggregation) - - all_logits.append(logits) - - dist_per_token = torch.distributions.Bernoulli(logits=logits) - probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to( - dist_per_token.probs.device - ) - - coords_to_probs = collections.defaultdict(list) - for i, p in enumerate(probabilities.squeeze().tolist()): + for i in range(model_labels.shape[0]): segment_id = token_type_ids_example[:, 0].tolist()[i] - col = token_type_ids_example[:, 1].tolist()[i] - 1 - row = token_type_ids_example[:, 2].tolist()[i] - 1 - if col >= 0 and row >= 0 and segment_id == 1: - coords_to_probs[(col, row)].append(p) + col_id = token_type_ids_example[:, 1].tolist()[i] - 1 + row_id = token_type_ids_example[:, 2].tolist()[i] - 1 - prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs} + if row_id >= 0 and col_id >= 0 and segment_id == 1: + model_labels[i] = int(prev_answers[(col_id, row_id)]) - logits_batch = torch.cat(tuple(all_logits), 0) + token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device) - return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0)) + input_ids_example = input_ids[index] + attention_mask_example = attention_mask[index] # shape (seq_len,) + token_type_ids_example = token_type_ids[index] # shape (seq_len, 7) + outputs = self.model( + input_ids=input_ids_example.unsqueeze(0), + attention_mask=attention_mask_example.unsqueeze(0), + token_type_ids=token_type_ids_example.unsqueeze(0), + ) + logits = outputs.logits + + if self.aggregate: + all_aggregations.append(outputs.logits_aggregation) + + all_logits.append(logits) + + dist_per_token = torch.distributions.Bernoulli(logits=logits) + probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to( + dist_per_token.probs.device + ) + + coords_to_probs = collections.defaultdict(list) + for i, p in enumerate(probabilities.squeeze().tolist()): + segment_id = token_type_ids_example[:, 0].tolist()[i] + col = token_type_ids_example[:, 1].tolist()[i] - 1 + row = token_type_ids_example[:, 2].tolist()[i] - 1 + if col >= 0 and row >= 0 and segment_id == 1: + coords_to_probs[(col, row)].append(p) + + prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs} + + logits_batch = torch.cat(tuple(all_logits), 0) + + return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0)) def __call__(self, *args, **kwargs): r""" diff --git a/tests/test_pipelines_audio_classification.py b/tests/test_pipelines_audio_classification.py index a1cfaafe6d..1b0ad5d2cb 100644 --- a/tests/test_pipelines_audio_classification.py +++ b/tests/test_pipelines_audio_classification.py @@ -117,7 +117,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest self.assertEqual( nested_simplify(output, decimals=4), [ - {"score": 0.9809, "label": "go"}, + {"score": 0.981, "label": "go"}, {"score": 0.0073, "label": "up"}, {"score": 0.0064, "label": "_unknown_"}, {"score": 0.0015, "label": "down"},