[run_clm] clarify why we get the tokenizer warning on long input (#11145)
* clarify why we get the warning here
* Update examples/language-modeling/run_clm.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* wording
* style
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from transformers import (
|
|||||||
default_data_collator,
|
default_data_collator,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
from transformers.testing_utils import CaptureLogger
|
||||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
|
|
||||||
@@ -317,7 +318,15 @@ def main():
|
|||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
|
|
||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
return tokenizer(examples[text_column_name])
|
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
||||||
|
with CaptureLogger(tok_logger) as cl:
|
||||||
|
output = tokenizer(examples[text_column_name])
|
||||||
|
# clm input could be much much longer than block_size
|
||||||
|
if "Token indices sequence length is longer than the" in cl.out:
|
||||||
|
tok_logger.warning(
|
||||||
|
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
tokenized_datasets = datasets.map(
|
tokenized_datasets = datasets.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
|
|||||||
Reference in New Issue
Block a user