DataLoader and Dataset
I. Overview
II. Details
DataLoader source code
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of sample loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

The main constructor arguments are as follows:
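DataLoader(dataset,
           batch_size=1,  # number of samples in each batch
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_workers=0,
           collate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

# Purpose: build an iterable data loader
# dataset: a Dataset instance; determines where the data is read from and how it is read
# batch_size: number of samples per batch
# num_workers: number of subprocesses used to load data (0 loads in the main process)
# shuffle: whether to reshuffle the data at every epoch
# drop_last: whether to drop the last incomplete batch when the number of samples is not divisible by batch_size

# Epoch: one pass in which every training sample has been fed to the model
# Iteration: one batch of samples fed to the model
# Batch size: number of samples per batch; it determines how many Iterations make up one Epoch

Example: with 81 samples and batch_size=8 (81 = 8 * 10 + 1):
1 Epoch = 10 Iterations when drop_last=True (the single leftover sample is dropped)
1 Epoch = 11 Iterations when drop_last=False (the last batch holds only 1 sample)

Dataset
torch.utils.data.Dataset
Purpose: the abstract Dataset class. Every custom Dataset must inherit from it and override __getitem__.
__getitem__: receives an index and returns one sample.

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

# Example (fragment): self.data_info and self.transform are assumed to be set in __init__, which is not shown
from PIL import Image

class Dataset(object):
    def __getitem__(self, index):
        # data_info holds (image path, label) pairs
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # pixel values 0~255
        if self.transform is not None:
            img = self.transform(img)
        return img, label

Reading data comes down to three questions:
1. Which data to read - the indices output by the Sampler
2. Where to read the data from - data_dir in the Dataset
3. How to read the data - __getitem__ in the Dataset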
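The three questions above, together with the 81-sample example, can be sketched in plain Python. This is only a simplified, single-process stand-in for what DataLoader does with a map-style dataset, not the library's implementation: ToyDataset, iter_batches and iterations_per_epoch are hypothetical names introduced for illustration, and the samples are made-up (index, label) pairs rather than images.

```python
import math
import random


class ToyDataset:
    """Hypothetical map-style dataset: sample i is the pair (i, i % 2)."""

    def __init__(self, num_samples):
        # Stand-in for the (path, label) list a real Dataset would build from data_dir.
        self.data_info = [(i, i % 2) for i in range(num_samples)]

    def __getitem__(self, index):
        # "How to read": turn one index into one sample.
        return self.data_info[index]

    def __len__(self):
        return len(self.data_info)


def iter_batches(dataset, batch_size=1, shuffle=False, drop_last=False, seed=None):
    """Simplified illustration of batching over a map-style dataset."""
    # "Which data to read": produce indices, in order or shuffled.
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)

    # Group the indices into batches of batch_size.
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            # Fetch each sample through __getitem__ (no tensor collation here).
            yield [dataset[i] for i in batch]
            batch = []
    # The last incomplete batch is kept unless drop_last is set.
    if batch and not drop_last:
        yield [dataset[i] for i in batch]


def iterations_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches (Iterations) in one Epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


dataset = ToyDataset(81)
kept = list(iter_batches(dataset, batch_size=8, drop_last=False))
dropped = list(iter_batches(dataset, batch_size=8, drop_last=True))
print(len(kept), len(dropped))  # 11 10
print(len(kept[-1]))            # 1
```

With the real torch.utils.data.DataLoader, the same counts should be visible through len(loader) for a map-style dataset; the real loader additionally merges each list of samples into a mini-batch via collate_fn, which this sketch leaves out.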