Replace dict/BatchEncoding instance checks by Mapping (#17014)
* Replace dict/BatchEncoding instance checks by Mapping * Typo
This commit is contained in:
@@ -14,11 +14,12 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
||||||
|
|
||||||
from ..models.bert import BertTokenizer, BertTokenizerFast
|
from ..models.bert import BertTokenizer, BertTokenizerFast
|
||||||
from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from ..utils import PaddingStrategy
|
from ..utils import PaddingStrategy
|
||||||
|
|
||||||
|
|
||||||
@@ -101,7 +102,7 @@ class DefaultDataCollator(DataCollatorMixin):
|
|||||||
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if not isinstance(features[0], (dict, BatchEncoding)):
|
if not isinstance(features[0], Mapping):
|
||||||
features = [vars(f) for f in features]
|
features = [vars(f) for f in features]
|
||||||
first = features[0]
|
first = features[0]
|
||||||
batch = {}
|
batch = {}
|
||||||
@@ -136,7 +137,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
if not isinstance(features[0], (dict, BatchEncoding)):
|
if not isinstance(features[0], Mapping):
|
||||||
features = [vars(f) for f in features]
|
features = [vars(f) for f in features]
|
||||||
first = features[0]
|
first = features[0]
|
||||||
batch = {}
|
batch = {}
|
||||||
@@ -177,7 +178,7 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
|||||||
def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if not isinstance(features[0], (dict, BatchEncoding)):
|
if not isinstance(features[0], Mapping):
|
||||||
features = [vars(f) for f in features]
|
features = [vars(f) for f in features]
|
||||||
first = features[0]
|
first = features[0]
|
||||||
batch = {}
|
batch = {}
|
||||||
@@ -687,7 +688,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
|
batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
else:
|
else:
|
||||||
batch = {
|
batch = {
|
||||||
@@ -724,7 +725,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
|
|
||||||
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
|
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
else:
|
else:
|
||||||
batch = {
|
batch = {
|
||||||
@@ -781,7 +782,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
|
batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
else:
|
else:
|
||||||
batch = {
|
batch = {
|
||||||
@@ -858,7 +859,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
</Tip>"""
|
</Tip>"""
|
||||||
|
|
||||||
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
input_ids = [e["input_ids"] for e in examples]
|
input_ids = [e["input_ids"] for e in examples]
|
||||||
else:
|
else:
|
||||||
input_ids = examples
|
input_ids = examples
|
||||||
@@ -886,7 +887,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
|
|
||||||
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
input_ids = [e["input_ids"] for e in examples]
|
input_ids = [e["input_ids"] for e in examples]
|
||||||
else:
|
else:
|
||||||
input_ids = examples
|
input_ids = examples
|
||||||
@@ -914,7 +915,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
|||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
|
|
||||||
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
input_ids = [e["input_ids"] for e in examples]
|
input_ids = [e["input_ids"] for e in examples]
|
||||||
else:
|
else:
|
||||||
input_ids = examples
|
input_ids = examples
|
||||||
@@ -1207,21 +1208,21 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
|
|||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
|
||||||
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = _torch_collate_batch(examples, self.tokenizer)
|
batch = _torch_collate_batch(examples, self.tokenizer)
|
||||||
inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
|
||||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||||
|
|
||||||
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = _tf_collate_batch(examples, self.tokenizer)
|
batch = _tf_collate_batch(examples, self.tokenizer)
|
||||||
inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
|
||||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||||
|
|
||||||
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], Mapping):
|
||||||
examples = [e["input_ids"] for e in examples]
|
examples = [e["input_ids"] for e in examples]
|
||||||
batch = _numpy_collate_batch(examples, self.tokenizer)
|
batch = _numpy_collate_batch(examples, self.tokenizer)
|
||||||
inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
|
inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
@@ -39,7 +40,6 @@ from .configuration_utils import PretrainedConfig
|
|||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation_tf_utils import TFGenerationMixin
|
from .generation_tf_utils import TFGenerationMixin
|
||||||
from .tf_utils import shape_list
|
from .tf_utils import shape_list
|
||||||
from .tokenization_utils_base import BatchEncoding
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
@@ -471,7 +471,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
||||||
)
|
)
|
||||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
elif isinstance(input_ids, Mapping):
|
||||||
if "inputs" in input_ids:
|
if "inputs" in input_ids:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" TensorFlow Hubert model."""
|
""" TensorFlow Hubert model."""
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -24,7 +25,6 @@ from ...activations_tf import get_tf_activation
|
|||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
||||||
from ...tf_utils import shape_list, stable_softmax
|
from ...tf_utils import shape_list, stable_softmax
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -97,7 +97,7 @@ def input_values_processing(func, config, input_values, **kwargs):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
||||||
)
|
)
|
||||||
elif isinstance(input_values, (dict, BatchEncoding)):
|
elif isinstance(input_values, Mapping):
|
||||||
if "inputs" in input_values:
|
if "inputs" in input_values:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
|
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1140,7 +1141,7 @@ class LukeTokenizer(RobertaTokenizer):
|
|||||||
"""
|
"""
|
||||||
# If we have a list of dicts, let's convert it in a dict of lists
|
# If we have a list of dicts, let's convert it in a dict of lists
|
||||||
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
||||||
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
|
||||||
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
||||||
|
|
||||||
# The model's main input name, usually `input_ids`, has be passed for padding
|
# The model's main input name, usually `input_ids`, has be passed for padding
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Mapping
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -1253,7 +1254,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
# If we have a list of dicts, let's convert it in a dict of lists
|
# If we have a list of dicts, let's convert it in a dict of lists
|
||||||
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
||||||
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
|
||||||
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
||||||
|
|
||||||
# The model's main input name, usually `input_ids`, has be passed for padding
|
# The model's main input name, usually `input_ids`, has be passed for padding
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -26,7 +27,6 @@ from ...activations_tf import get_tf_activation
|
|||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
|
||||||
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
|
||||||
from ...tf_utils import shape_list, stable_softmax
|
from ...tf_utils import shape_list, stable_softmax
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -135,7 +135,7 @@ def input_values_processing(func, config, input_values, **kwargs):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
||||||
)
|
)
|
||||||
elif isinstance(input_values, (dict, BatchEncoding)):
|
elif isinstance(input_values, Mapping):
|
||||||
if "inputs" in input_values:
|
if "inputs" in input_values:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
|
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections.abc import Mapping
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -1459,13 +1460,11 @@ def nested_simplify(obj, decimals=3):
|
|||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.tokenization_utils import BatchEncoding
|
|
||||||
|
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return [nested_simplify(item, decimals) for item in obj]
|
return [nested_simplify(item, decimals) for item in obj]
|
||||||
elif isinstance(obj, np.ndarray):
|
elif isinstance(obj, np.ndarray):
|
||||||
return nested_simplify(obj.tolist())
|
return nested_simplify(obj.tolist())
|
||||||
elif isinstance(obj, (dict, BatchEncoding)):
|
elif isinstance(obj, Mapping):
|
||||||
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
|
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
|
||||||
elif isinstance(obj, (str, int, np.int64)):
|
elif isinstance(obj, (str, int, np.int64)):
|
||||||
return obj
|
return obj
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
@@ -2768,7 +2769,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
"""
|
"""
|
||||||
# If we have a list of dicts, let's convert it in a dict of lists
|
# If we have a list of dicts, let's convert it in a dict of lists
|
||||||
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
|
||||||
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
|
||||||
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
||||||
|
|
||||||
# The model's main input name, usually `input_ids`, has be passed for padding
|
# The model's main input name, usually `input_ids`, has be passed for padding
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from logging import StreamHandler
|
from logging import StreamHandler
|
||||||
@@ -111,7 +112,7 @@ def find_batch_size(tensors):
|
|||||||
result = find_batch_size(t)
|
result = find_batch_size(t)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
elif isinstance(tensors, (dict, BatchEncoding)):
|
elif isinstance(tensors, Mapping):
|
||||||
for key, value in tensors.items():
|
for key, value in tensors.items():
|
||||||
result = find_batch_size(value)
|
result = find_batch_size(value)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user