Replace dict/BatchEncoding instance checks by Mapping (#17014)

* Replace dict/BatchEncoding instance checks by Mapping

* Typo
This commit is contained in:
Sylvain Gugger
2022-04-29 17:20:52 -04:00
committed by GitHub
parent b8dffd1f3e
commit 18df440709
9 changed files with 30 additions and 26 deletions

View File

@@ -24,6 +24,7 @@ import os
import re
import warnings
from collections import OrderedDict, UserDict
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"""
# If we have a list of dicts, let's convert it into a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
# The model's main input name, usually `input_ids`, has to be passed for padding