Attention mask is important in the case of batching... (#16222)
* Attention mask is important in the case of batching... * Improve the fix. * Making the sentences different enough that they exhibit different predictions.
This commit is contained in:
@@ -149,7 +149,7 @@ def pad_collate_fn(tokenizer, feature_extractor):
|
|||||||
_padding_value = t_padding_value
|
_padding_value = t_padding_value
|
||||||
elif key in {"input_values", "pixel_values", "input_features"}:
|
elif key in {"input_values", "pixel_values", "input_features"}:
|
||||||
_padding_value = f_padding_value
|
_padding_value = f_padding_value
|
||||||
elif key in {"p_mask"}:
|
elif key in {"p_mask", "special_tokens_mask"}:
|
||||||
_padding_value = 1
|
_padding_value = 1
|
||||||
elif key in {"attention_mask", "token_type_ids"}:
|
elif key in {"attention_mask", "token_type_ids"}:
|
||||||
_padding_value = 0
|
_padding_value = 0
|
||||||
|
|||||||
@@ -192,7 +192,6 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
|
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
|
||||||
model_inputs = self.tokenizer(
|
model_inputs = self.tokenizer(
|
||||||
sentence,
|
sentence,
|
||||||
return_attention_mask=False,
|
|
||||||
return_tensors=self.framework,
|
return_tensors=self.framework,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
return_special_tokens_mask=True,
|
return_special_tokens_mask=True,
|
||||||
|
|||||||
@@ -649,6 +649,23 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Batch size does not affect outputs (attention_mask are required)
|
||||||
|
sentences = ["This is a test !", "Another test this is with longer sentence"]
|
||||||
|
outputs = token_classifier(sentences)
|
||||||
|
outputs_batched = token_classifier(sentences, batch_size=2)
|
||||||
|
# Batching does not make a difference in predictions
|
||||||
|
self.assertEqual(nested_simplify(outputs_batched), nested_simplify(outputs))
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs_batched),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
|
||||||
|
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
|
||||||
|
],
|
||||||
|
[],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
|
||||||
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
|
||||||
|
|||||||
Reference in New Issue
Block a user