From 1e05671d213642a67dd418c13a18a17a18c25117 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 1 May 2024 08:43:02 +0100 Subject: [PATCH] Fix QA example (#30580) * Handle cases when CLS token is absent * Use BOS token as a fallback --- examples/pytorch/question-answering/run_qa.py | 7 ++++++- .../question-answering/run_qa_beam_search.py | 14 ++++++++++++-- .../run_qa_beam_search_no_trainer.py | 14 ++++++++++++-- .../question-answering/run_qa_no_trainer.py | 7 ++++++- 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index 07e3a31366..ba8e955336 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -434,7 +434,12 @@ def main(): for i, offsets in enumerate(offset_mapping): # We will label impossible answers with the index of the CLS token. input_ids = tokenized_examples["input_ids"][i] - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 # Grab the sequence corresponding to that example (to know what is the context and what is the question). sequence_ids = tokenized_examples.sequence_ids(i) diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py index 9f2d39540c..f5003acd96 100755 --- a/examples/pytorch/question-answering/run_qa_beam_search.py +++ b/examples/pytorch/question-answering/run_qa_beam_search.py @@ -417,7 +417,12 @@ def main(): for i, offsets in enumerate(offset_mapping): # We will label impossible answers with the index of the CLS token. input_ids = tokenized_examples["input_ids"][i] - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 tokenized_examples["cls_index"].append(cls_index) # Grab the sequence corresponding to that example (to know what is the context and what is the question). @@ -534,7 +539,12 @@ def main(): for i, input_ids in enumerate(tokenized_examples["input_ids"]): # Find the CLS token in the input ids. - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 tokenized_examples["cls_index"].append(cls_index) # Grab the sequence corresponding to that example (to know what is the context and what is the question). diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py index 4425c1118b..9f1c6a0215 100644 --- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py +++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py @@ -444,7 +444,12 @@ def main(): for i, offsets in enumerate(offset_mapping): # We will label impossible answers with the index of the CLS token. input_ids = tokenized_examples["input_ids"][i] - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 tokenized_examples["cls_index"].append(cls_index) # Grab the sequence corresponding to that example (to know what is the context and what is the question). @@ -563,7 +568,12 @@ def main(): for i, input_ids in enumerate(tokenized_examples["input_ids"]): # Find the CLS token in the input ids. - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 tokenized_examples["cls_index"].append(cls_index) # Grab the sequence corresponding to that example (to know what is the context and what is the question). diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py index d9f044dae4..3fbd6fdf53 100755 --- a/examples/pytorch/question-answering/run_qa_no_trainer.py +++ b/examples/pytorch/question-answering/run_qa_no_trainer.py @@ -513,7 +513,12 @@ def main(): for i, offsets in enumerate(offset_mapping): # We will label impossible answers with the index of the CLS token. input_ids = tokenized_examples["input_ids"][i] - cls_index = input_ids.index(tokenizer.cls_token_id) + if tokenizer.cls_token_id in input_ids: + cls_index = input_ids.index(tokenizer.cls_token_id) + elif tokenizer.bos_token_id in input_ids: + cls_index = input_ids.index(tokenizer.bos_token_id) + else: + cls_index = 0 # Grab the sequence corresponding to that example (to know what is the context and what is the question). sequence_ids = tokenized_examples.sequence_ids(i)