From 18df440709f1b19d1c5617c0d987c5ff8fd0915d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 29 Apr 2022 17:20:52 -0400 Subject: [PATCH] Replace dict/BatchEncoding instance checks by Mapping (#17014) * Replace dict/BatchEncoding instance checks by Mapping * Typo --- src/transformers/data/data_collator.py | 27 ++++++++++--------- src/transformers/modeling_tf_utils.py | 4 +-- .../models/hubert/modeling_tf_hubert.py | 4 +-- .../models/luke/tokenization_luke.py | 3 ++- .../models/mluke/tokenization_mluke.py | 3 ++- .../models/wav2vec2/modeling_tf_wav2vec2.py | 4 +-- src/transformers/testing_utils.py | 5 ++-- src/transformers/tokenization_utils_base.py | 3 ++- src/transformers/trainer_pt_utils.py | 3 ++- 9 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 0c9276e948..fc1dd25eb3 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -14,11 +14,12 @@ import random import warnings +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union from ..models.bert import BertTokenizer, BertTokenizerFast -from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase +from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import PaddingStrategy @@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin): def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: import torch - if not isinstance(features[0], (dict, BatchEncoding)): + if not isinstance(features[0], Mapping): features = [vars(f) for f in features] first = features[0] batch = {} @@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: import numpy as np import tensorflow as tf - if not isinstance(features[0], (dict, BatchEncoding)): + if not isinstance(features[0], Mapping): features = [vars(f) for f in features] first = features[0] batch = {} @@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: import numpy as np - if not isinstance(features[0], (dict, BatchEncoding)): + if not isinstance(features[0], Mapping): features = [vars(f) for f in features] first = features[0] batch = {} @@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): import tensorflow as tf # Handle dict or lists with proper padding and conversion to tensor. - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of) else: batch = { @@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: # Handle dict or lists with proper padding and conversion to tensor. - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of) else: batch = { @@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): import numpy as np # Handle dict or lists with proper padding and conversion to tensor. - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of) else: batch = { @@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): """ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: input_ids = examples @@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): return {"input_ids": inputs, "labels": labels} def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: input_ids = examples @@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): return {"input_ids": inputs, "labels": labels} def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): input_ids = [e["input_ids"] for e in examples] else: input_ids = examples @@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin): return_tensors: str = "pt" def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): examples = [e["input_ids"] for e in examples] batch = _torch_collate_batch(examples, self.tokenizer) inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch) return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): examples = [e["input_ids"] for e in examples] batch = _tf_collate_batch(examples, self.tokenizer) inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch) return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels} def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: - if isinstance(examples[0], (dict, BatchEncoding)): + if isinstance(examples[0], Mapping): examples = [e["input_ids"] for e in examples] batch = _numpy_collate_batch(examples, self.tokenizer) inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index efa37e32bd..dacacbb28a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -21,6 +21,7 @@ import os import pickle import re import warnings +from collections.abc import Mapping from typing import Dict, List, Optional, Union import h5py @@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig from .dynamic_module_utils import custom_object_save from .generation_tf_utils import TFGenerationMixin from .tf_utils import shape_list -from .tokenization_utils_base import BatchEncoding from .utils import ( DUMMY_INPUTS, HUGGINGFACE_CO_RESOLVE_ENDPOINT, @@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs): raise ValueError( f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}." ) - elif isinstance(input_ids, (dict, BatchEncoding)): + elif isinstance(input_ids, Mapping): if "inputs" in input_ids: warnings.warn( "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.", diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index eb79815f1a..540090871f 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -15,6 +15,7 @@ """ TensorFlow Hubert model.""" import inspect import warnings +from collections.abc import Mapping from typing import Any, Dict, Optional, Tuple, Union import numpy as np @@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable from ...tf_utils import shape_list, stable_softmax -from ...tokenization_utils_base import BatchEncoding from ...utils import ( ModelOutput, add_start_docstrings, @@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs): raise ValueError( f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}." ) - elif isinstance(input_values, (dict, BatchEncoding)): + elif isinstance(input_values, Mapping): if "inputs" in input_values: warnings.warn( "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.", diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index e35db36aed..e75fda42ca 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -17,6 +17,7 @@ import itertools import json import os +from collections.abc import Mapping from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer): """ # If we have a list of dicts, let's convert it in a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader - if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} # The model's main input name, usually `input_ids`, has be passed for padding diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py index 1ddf472d56..24a6304fc1 100644 --- a/src/transformers/models/mluke/tokenization_mluke.py +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -18,6 +18,7 @@ import itertools import json import os +from collections.abc import Mapping from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union @@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer): """ # If we have a list of dicts, let's convert it in a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader - if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} # The model's main input name, usually `input_ids`, has be passed for padding diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index 9bbb908eb0..bac62f148c 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -16,6 +16,7 @@ import inspect import warnings +from collections.abc import Mapping from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union @@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable from ...tf_utils import shape_list, stable_softmax -from ...tokenization_utils_base import BatchEncoding from ...utils import ( ModelOutput, add_start_docstrings, @@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs): raise ValueError( f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}." ) - elif isinstance(input_values, (dict, BatchEncoding)): + elif isinstance(input_values, Mapping): if "inputs" in input_values: warnings.warn( "The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.", diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 6e4546afb1..86d3673b74 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -22,6 +22,7 @@ import shutil import sys import tempfile import unittest +from collections.abc import Mapping from distutils.util import strtobool from io import StringIO from pathlib import Path @@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3): """ import numpy as np - from transformers.tokenization_utils import BatchEncoding - if isinstance(obj, list): return [nested_simplify(item, decimals) for item in obj] elif isinstance(obj, np.ndarray): return nested_simplify(obj.tolist()) - elif isinstance(obj, (dict, BatchEncoding)): + elif isinstance(obj, Mapping): return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} elif isinstance(obj, (str, int, np.int64)): return obj diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index d75b05c057..f587cc060d 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -24,6 +24,7 @@ import os import re import warnings from collections import OrderedDict, UserDict +from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): """ # If we have a list of dicts, let's convert it in a dict of lists # We do this to allow using this method as a collate_fn function in PyTorch Dataloader - if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping): encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} # The model's main input name, usually `input_ids`, has be passed for padding diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index d76552c375..ac83826e40 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -22,6 +22,7 @@ import math import os import sys import warnings +from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass from logging import StreamHandler @@ -111,7 +112,7 @@ def find_batch_size(tensors): result = find_batch_size(t) if result is not None: return result - elif isinstance(tensors, (dict, BatchEncoding)): + elif isinstance(tensors, Mapping): for key, value in tensors.items(): result = find_batch_size(value) if result is not None: