Update namespaces inside torch.utils.data to the latest. (#13167)

* Update torch.utils.data namespaces to the latest.

* Format

* Update Dataloader.

* Style
This commit is contained in:
Allan Lin
2021-08-19 20:29:51 +08:00
committed by GitHub
parent 1fec32adc6
commit 91ff480e26
24 changed files with 41 additions and 44 deletions

View File

@@ -77,7 +77,7 @@ class Split(Enum):
if is_torch_available(): if is_torch_available():
import torch import torch
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
class MultipleChoiceDataset(Dataset): class MultipleChoiceDataset(Dataset):
""" """

View File

@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
) )
return scheduler return scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset): if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None return None
elif is_torch_tpu_available(): elif is_torch_tpu_available():

View File

@@ -206,7 +206,7 @@ class TokenClassificationTask:
if is_torch_available(): if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
class TokenClassificationDataset(Dataset): class TokenClassificationDataset(Dataset):
""" """

View File

@@ -31,7 +31,7 @@ import random
import datasets import datasets
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -31,7 +31,7 @@ import random
import datasets import datasets
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -29,7 +29,7 @@ from typing import Optional, Union
import datasets import datasets
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -28,7 +28,7 @@ import datasets
import numpy as np import numpy as np
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -28,7 +28,7 @@ import datasets
import numpy as np import numpy as np
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -29,7 +29,7 @@ import nltk
import numpy as np import numpy as np
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -21,7 +21,7 @@ import random
import datasets import datasets
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -27,7 +27,7 @@ import random
import datasets import datasets
import torch import torch
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -28,7 +28,7 @@ import datasets
import numpy as np import numpy as np
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -88,7 +88,7 @@ class InputFeatures:
if is_torch_available(): if is_torch_available():
import torch import torch
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
class HansDataset(Dataset): class HansDataset(Dataset):
""" """

View File

@@ -19,7 +19,7 @@ import copy
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler from torch.utils.data import BatchSampler, Sampler
from utils import logger from utils import logger

View File

@@ -20,7 +20,7 @@ from enum import Enum
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from filelock import FileLock from filelock import FileLock

View File

@@ -21,7 +21,7 @@ import warnings
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch import torch
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from filelock import FileLock from filelock import FileLock

View File

@@ -19,7 +19,7 @@ from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import torch import torch
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from filelock import FileLock from filelock import FileLock

View File

@@ -49,10 +49,8 @@ import numpy as np
import torch import torch
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from . import __version__ from . import __version__
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
@@ -206,16 +204,16 @@ class Trainer:
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`. The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise. :func:`~transformers.DataCollatorWithPadding` otherwise.
train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`): train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. ``model.forward()`` method are automatically removed.
Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
training in a distributed fashion, your iterable dataset should either use a internal attribute a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
:obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
:obj:`set_epoch()` method that internally sets the seed of the RNGs used. internally sets the seed of the RNGs used.
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. ``model.forward()`` method are automatically removed.
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`): tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
@@ -537,7 +535,7 @@ class Trainer:
else: else:
return dataset.remove_columns(ignored_columns) return dataset.remove_columns(ignored_columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized): if not isinstance(self.train_dataset, collections.abc.Sized):
return None return None
@@ -617,7 +615,7 @@ class Trainer:
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training") train_dataset = self._remove_unused_columns(train_dataset, description="training")
if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset): if isinstance(train_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1: if self.args.world_size > 1:
train_dataset = IterableDatasetShard( train_dataset = IterableDatasetShard(
train_dataset, train_dataset,
@@ -647,7 +645,7 @@ class Trainer:
pin_memory=self.args.dataloader_pin_memory, pin_memory=self.args.dataloader_pin_memory,
) )
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
# Deprecated code # Deprecated code
if self.args.use_legacy_prediction_loop: if self.args.use_legacy_prediction_loop:
if is_torch_tpu_available(): if is_torch_tpu_available():
@@ -683,7 +681,7 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior. Subclass and override this method if you want to inject some custom behavior.
Args: Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
""" """
@@ -694,7 +692,7 @@ class Trainer:
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset): if isinstance(eval_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1: if self.args.world_size > 1:
eval_dataset = IterableDatasetShard( eval_dataset = IterableDatasetShard(
eval_dataset, eval_dataset,
@@ -730,14 +728,14 @@ class Trainer:
Subclass and override this method if you want to inject some custom behavior. Subclass and override this method if you want to inject some custom behavior.
Args: Args:
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`): test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
""" """
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test") test_dataset = self._remove_unused_columns(test_dataset, description="test")
if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset): if isinstance(test_dataset, torch.utils.data.IterableDataset):
if self.args.world_size > 1: if self.args.world_size > 1:
test_dataset = IterableDatasetShard( test_dataset = IterableDatasetShard(
test_dataset, test_dataset,

View File

@@ -175,9 +175,9 @@ class TrainerCallback:
The optimizer used for the training steps. The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`): lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
The scheduler used for setting the learning rate. The scheduler used for setting the learning rate.
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`): train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training. The current dataloader used for training.
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`): eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
The current dataloader used for training. The current dataloader used for training.
metrics (:obj:`Dict[str, float]`): metrics (:obj:`Dict[str, float]`):
The metrics computed by the last evaluation phase. The metrics computed by the last evaluation phase.

View File

@@ -29,9 +29,8 @@ import numpy as np
import torch import torch
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
from .tokenization_utils_base import BatchEncoding from .tokenization_utils_base import BatchEncoding
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
return self.num_samples return self.num_samples
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int): def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
if xm.xrt_world_size() <= 1: if xm.xrt_world_size() <= 1:
return RandomSampler(dataset) return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
Args: Args:
dataset (:obj:`torch.utils.data.dataset.IterableDataset`): dataset (:obj:`torch.utils.data.IterableDataset`):
The batch sampler to split in several shards. The batch sampler to split in several shards.
batch_size (:obj:`int`, `optional`, defaults to 1): batch_size (:obj:`int`, `optional`, defaults to 1):
The size of the batches per shard. The size of the batches per shard.

View File

@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from .deepspeed import is_deepspeed_zero3_enabled from .deepspeed import is_deepspeed_zero3_enabled
from .trainer import Trainer from .trainer import Trainer

View File

@@ -499,7 +499,7 @@ import random
import datasets import datasets
from datasets import load_dataset, load_metric from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import transformers import transformers

View File

@@ -31,7 +31,7 @@ logger = logging.get_logger(__name__)
if is_torch_available(): if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from transformers import Trainer from transformers import Trainer

View File

@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
if is_torch_available(): if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from torch.utils.data.dataset import Dataset from torch.utils.data import Dataset
from transformers import Trainer from transformers import Trainer