Update namespaces inside torch.utils.data to the latest. (#13167)
* Update torch.utils.data namespaces to the latest. * Format * Update Dataloader. * Style
This commit is contained in:
@@ -77,7 +77,7 @@ class Split(Enum):
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
class MultipleChoiceDataset(Dataset):
|
class MultipleChoiceDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
return None
|
return None
|
||||||
elif is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ class TokenClassificationTask:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
class TokenClassificationDataset(Dataset):
|
class TokenClassificationDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ import random
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ import random
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from typing import Optional, Union
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import nltk
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import random
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import random
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import ClassLabel, load_dataset, load_metric
|
from datasets import ClassLabel, load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class InputFeatures:
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
class HansDataset(Dataset):
|
class HansDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import copy
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import BatchSampler, Sampler
|
from torch.utils.data import BatchSampler, Sampler
|
||||||
|
|
||||||
from utils import logger
|
from utils import logger
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from enum import Enum
|
|||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import warnings
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from enum import Enum
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
|
|||||||
@@ -49,10 +49,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.dataset import Dataset, IterableDataset
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
@@ -206,16 +204,16 @@ class Trainer:
|
|||||||
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
|
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
|
||||||
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
|
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
|
||||||
:func:`~transformers.DataCollatorWithPadding` otherwise.
|
:func:`~transformers.DataCollatorWithPadding` otherwise.
|
||||||
train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
|
train_dataset (:obj:`torch.utils.data.Dataset` or :obj:`torch.utils.data.IterableDataset`, `optional`):
|
||||||
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||||
``model.forward()`` method are automatically removed.
|
``model.forward()`` method are automatically removed.
|
||||||
|
|
||||||
Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
|
Note that if it's a :obj:`torch.utils.data.IterableDataset` with some randomization and you are training in
|
||||||
training in a distributed fashion, your iterable dataset should either use a internal attribute
|
a distributed fashion, your iterable dataset should either use a internal attribute :obj:`generator` that
|
||||||
:obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
|
is a :obj:`torch.Generator` for the randomization that must be identical on all processes (and the Trainer
|
||||||
processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
|
will manually set the seed of this :obj:`generator` at each epoch) or have a :obj:`set_epoch()` method that
|
||||||
:obj:`set_epoch()` method that internally sets the seed of the RNGs used.
|
internally sets the seed of the RNGs used.
|
||||||
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
|
||||||
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||||
``model.forward()`` method are automatically removed.
|
``model.forward()`` method are automatically removed.
|
||||||
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
|
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
|
||||||
@@ -537,7 +535,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
return dataset.remove_columns(ignored_columns)
|
return dataset.remove_columns(ignored_columns)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
if not isinstance(self.train_dataset, collections.abc.Sized):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -617,7 +615,7 @@ class Trainer:
|
|||||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||||
|
|
||||||
if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset):
|
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
train_dataset = IterableDatasetShard(
|
train_dataset = IterableDatasetShard(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@@ -647,7 +645,7 @@ class Trainer:
|
|||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
|
||||||
# Deprecated code
|
# Deprecated code
|
||||||
if self.args.use_legacy_prediction_loop:
|
if self.args.use_legacy_prediction_loop:
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -683,7 +681,7 @@ class Trainer:
|
|||||||
Subclass and override this method if you want to inject some custom behavior.
|
Subclass and override this method if you want to inject some custom behavior.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
eval_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
|
||||||
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
||||||
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
||||||
"""
|
"""
|
||||||
@@ -694,7 +692,7 @@ class Trainer:
|
|||||||
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||||
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
|
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
|
||||||
|
|
||||||
if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
|
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
eval_dataset = IterableDatasetShard(
|
eval_dataset = IterableDatasetShard(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
@@ -730,14 +728,14 @@ class Trainer:
|
|||||||
Subclass and override this method if you want to inject some custom behavior.
|
Subclass and override this method if you want to inject some custom behavior.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
test_dataset (:obj:`torch.utils.data.Dataset`, `optional`):
|
||||||
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||||
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
||||||
"""
|
"""
|
||||||
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||||
test_dataset = self._remove_unused_columns(test_dataset, description="test")
|
test_dataset = self._remove_unused_columns(test_dataset, description="test")
|
||||||
|
|
||||||
if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
|
if isinstance(test_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
test_dataset = IterableDatasetShard(
|
test_dataset = IterableDatasetShard(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
|
|||||||
@@ -175,9 +175,9 @@ class TrainerCallback:
|
|||||||
The optimizer used for the training steps.
|
The optimizer used for the training steps.
|
||||||
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
||||||
The scheduler used for setting the learning rate.
|
The scheduler used for setting the learning rate.
|
||||||
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
train_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
|
||||||
The current dataloader used for training.
|
The current dataloader used for training.
|
||||||
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
eval_dataloader (:obj:`torch.utils.data.DataLoader`, `optional`):
|
||||||
The current dataloader used for training.
|
The current dataloader used for training.
|
||||||
metrics (:obj:`Dict[str, float]`):
|
metrics (:obj:`Dict[str, float]`):
|
||||||
The metrics computed by the last evaluation phase.
|
The metrics computed by the last evaluation phase.
|
||||||
|
|||||||
@@ -29,9 +29,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset, IterableDataset
|
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
|
||||||
|
|
||||||
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
|
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
@@ -290,7 +289,7 @@ class SequentialDistributedSampler(Sampler):
|
|||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|
||||||
|
|
||||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
|
def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
|
||||||
if xm.xrt_world_size() <= 1:
|
if xm.xrt_world_size() <= 1:
|
||||||
return RandomSampler(dataset)
|
return RandomSampler(dataset)
|
||||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||||
@@ -690,7 +689,7 @@ class IterableDatasetShard(IterableDataset):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
|
dataset (:obj:`torch.utils.data.IterableDataset`):
|
||||||
The batch sampler to split in several shards.
|
The batch sampler to split in several shards.
|
||||||
batch_size (:obj:`int`, `optional`, defaults to 1):
|
batch_size (:obj:`int`, `optional`, defaults to 1):
|
||||||
The size of the batches per shard.
|
The size of the batches per shard.
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from .deepspeed import is_deepspeed_zero3_enabled
|
from .deepspeed import is_deepspeed_zero3_enabled
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ import random
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ logger = logging.get_logger(__name__)
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user