Add min and max question length options to TapasTokenizer (#12803)
* Add min and max question length option to the tokenizer * Add corresponding test
This commit is contained in:
@@ -262,7 +262,10 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
Whether to add empty strings instead of column names.
|
Whether to add empty strings instead of column names.
|
||||||
update_answer_coordinates (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
update_answer_coordinates (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to recompute the answer coordinates from the answer text.
|
Whether to recompute the answer coordinates from the answer text.
|
||||||
|
min_question_length (:obj:`int`, `optional`):
|
||||||
|
Minimum length of each question in terms of tokens (will be skipped otherwise).
|
||||||
|
max_question_length (:obj:`int`, `optional`):
|
||||||
|
Maximum length of each question in terms of tokens (will be skipped otherwise).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
@@ -288,6 +291,8 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
max_row_id: int = None,
|
max_row_id: int = None,
|
||||||
strip_column_names: bool = False,
|
strip_column_names: bool = False,
|
||||||
update_answer_coordinates: bool = False,
|
update_answer_coordinates: bool = False,
|
||||||
|
min_question_length=None,
|
||||||
|
max_question_length=None,
|
||||||
model_max_length: int = 512,
|
model_max_length: int = 512,
|
||||||
additional_special_tokens: Optional[List[str]] = None,
|
additional_special_tokens: Optional[List[str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -318,6 +323,8 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
max_row_id=max_row_id,
|
max_row_id=max_row_id,
|
||||||
strip_column_names=strip_column_names,
|
strip_column_names=strip_column_names,
|
||||||
update_answer_coordinates=update_answer_coordinates,
|
update_answer_coordinates=update_answer_coordinates,
|
||||||
|
min_question_length=min_question_length,
|
||||||
|
max_question_length=max_question_length,
|
||||||
model_max_length=model_max_length,
|
model_max_length=model_max_length,
|
||||||
additional_special_tokens=additional_special_tokens,
|
additional_special_tokens=additional_special_tokens,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -346,6 +353,8 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
self.max_row_id = max_row_id if max_row_id is not None else self.model_max_length
|
self.max_row_id = max_row_id if max_row_id is not None else self.model_max_length
|
||||||
self.strip_column_names = strip_column_names
|
self.strip_column_names = strip_column_names
|
||||||
self.update_answer_coordinates = update_answer_coordinates
|
self.update_answer_coordinates = update_answer_coordinates
|
||||||
|
self.min_question_length = min_question_length
|
||||||
|
self.max_question_length = max_question_length
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def do_lower_case(self):
|
def do_lower_case(self):
|
||||||
@@ -729,6 +738,19 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_question_tokens(self, query):
|
||||||
|
"""Tokenizes the query, taking into account the max and min question length."""
|
||||||
|
|
||||||
|
query_tokens = self.tokenize(query)
|
||||||
|
if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
|
||||||
|
logger.warning("Skipping query as its tokens are longer than the max question length")
|
||||||
|
return "", []
|
||||||
|
if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
|
||||||
|
logger.warning("Skipping query as its tokens are shorter than the min question length")
|
||||||
|
return "", []
|
||||||
|
|
||||||
|
return query, query_tokens
|
||||||
|
|
||||||
def _batch_encode_plus(
|
def _batch_encode_plus(
|
||||||
self,
|
self,
|
||||||
table,
|
table,
|
||||||
@@ -757,8 +779,9 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
table_tokens = self._tokenize_table(table)
|
table_tokens = self._tokenize_table(table)
|
||||||
|
|
||||||
queries_tokens = []
|
queries_tokens = []
|
||||||
for query in queries:
|
for idx, query in enumerate(queries):
|
||||||
query_tokens = self.tokenize(query)
|
query, query_tokens = self._get_question_tokens(query)
|
||||||
|
queries[idx] = query
|
||||||
queries_tokens.append(query_tokens)
|
queries_tokens.append(query_tokens)
|
||||||
|
|
||||||
batch_outputs = self._batch_prepare_for_model(
|
batch_outputs = self._batch_prepare_for_model(
|
||||||
@@ -1015,7 +1038,7 @@ class TapasTokenizer(PreTrainedTokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
table_tokens = self._tokenize_table(table)
|
table_tokens = self._tokenize_table(table)
|
||||||
query_tokens = self.tokenize(query)
|
query, query_tokens = self._get_question_tokens(query)
|
||||||
|
|
||||||
return self.prepare_for_model(
|
return self.prepare_for_model(
|
||||||
table,
|
table,
|
||||||
|
|||||||
@@ -1076,6 +1076,37 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
|
self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
|
||||||
self.assertLessEqual(len(new_encoded_inputs), 20)
|
self.assertLessEqual(len(new_encoded_inputs), 20)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_min_max_question_length(self):
|
||||||
|
data = {
|
||||||
|
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
|
||||||
|
"Age": ["56", "45", "59"],
|
||||||
|
"Number of movies": ["87", "53", "69"],
|
||||||
|
"Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
|
||||||
|
}
|
||||||
|
queries = "When was Brad Pitt born?"
|
||||||
|
table = pd.DataFrame.from_dict(data)
|
||||||
|
|
||||||
|
# test max_question_length
|
||||||
|
tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", max_question_length=2)
|
||||||
|
|
||||||
|
encoding = tokenizer(table=table, queries=queries)
|
||||||
|
|
||||||
|
# query should not be tokenized as it's longer than the specified max_question_length
|
||||||
|
expected_results = [101, 102]
|
||||||
|
|
||||||
|
self.assertListEqual(encoding.input_ids[:2], expected_results)
|
||||||
|
|
||||||
|
# test min_question_length
|
||||||
|
tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", min_question_length=30)
|
||||||
|
|
||||||
|
encoding = tokenizer(table=table, queries=queries)
|
||||||
|
|
||||||
|
# query should not be tokenized as it's shorter than the specified min_question_length
|
||||||
|
expected_results = [101, 102]
|
||||||
|
|
||||||
|
self.assertListEqual(encoding.input_ids[:2], expected_results)
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
@is_pt_tf_cross_test
|
||||||
def test_batch_encode_plus_tensors(self):
|
def test_batch_encode_plus_tensors(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user