Common problems encountered when using DataLoader in PyTorch
A typical error looks like this:
RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224)),   ###
The cause is the argument passed to transforms.Resize(): a single int only resizes the shorter edge and keeps the aspect ratio, so images end up with different shapes and cannot be stacked into one batch. Changing it to the following fixes the problem:
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224, 224)),
Likewise, change the transform in val_dataset to transforms.Resize((224, 224)).
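A self-contained sketch of why the single-int Resize breaks batching and why the tuple form fixes it; the image sizes here are made up purely for illustration and are not from the original error:

```python
from PIL import Image
import torch
from torchvision import transforms

# Two images with different aspect ratios, analogous to the [3, 60, 32] / [3, 54, 32] tensors in the error
img_a = Image.new("RGB", (320, 600))
img_b = Image.new("RGB", (320, 540))

bad = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])        # int: keeps aspect ratio
good = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])  # tuple: fixed output size

print(bad(img_a).shape, bad(img_b).shape)    # different shapes -> default_collate cannot stack them
print(good(img_a).shape, good(img_b).shape)  # both [3, 224, 224] -> stacking works

batch = torch.stack([good(img_a), good(img_b)])  # what the DataLoader's collate step effectively does
print(batch.shape)  # torch.Size([2, 3, 224, 224])
```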
Supplement: a deeper look at PyTorch's DataLoader
- DataLoader is essentially an iterable object; you consume it via iter(), and you cannot call next() on it directly;
- iter(dataloader) returns an iterator, which can then be advanced with next();
- you can also just iterate it as an iterable with `for inputs, labels in dataloaders`;
- usually we implement a Dataset object and pass it to the DataLoader, which internally uses yield to return one batch of data at a time (see the sketch right after this list);
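A small runnable sketch of the two access patterns, using a toy TensorDataset as a stand-in for a real Dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset just to demonstrate the access patterns
dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
dataloader = DataLoader(dataset, batch_size=4)

# next(dataloader) would raise TypeError: DataLoader itself is only an iterable, not an iterator.
it = iter(dataloader)        # iter() hands back an iterator over the dataset
inputs, labels = next(it)    # now next() works and yields one batch

# Or iterate directly; the for statement calls iter() behind the scenes
for inputs, labels in dataloader:
    print(inputs.shape, labels.shape)
```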
① DataLoader is essentially an iterable (just like Python's built-in types such as list); it uses multiple worker processes to speed up batch preparation, and relies on yield so that only a bounded amount of memory is needed.
② Behaviour of the queue
When the queue is empty, queue.get() blocks; while it is blocked, a queue.put() from another process/thread wakes this process/thread up and the get then succeeds.
When the queue is full, queue.put() blocks. (A small sketch of this blocking behaviour follows right after this list.)
③ DataLoader is an efficient, concise and intuitive structure for feeding input data to the network, and it is easy to use and to extend.
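A minimal sketch of the queue blocking behaviour with multiprocessing.Queue; the queue size and sleep times are arbitrary illustrations:

```python
import time
import multiprocessing as mp

def producer(q):
    for i in range(3):
        time.sleep(0.5)
        q.put(i)          # put() would block here if the queue were already full (maxsize reached)

if __name__ == "__main__":
    q = mp.Queue(maxsize=2)
    p = mp.Process(target=producer, args=(q,))
    p.start()
    for _ in range(3):
        item = q.get()    # get() blocks until the producer has put something
        print("got", item)
    p.join()
```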
The input data pipeline
The sequence of operations for getting data from disk into a PyTorch model is:
① create a Dataset object
② create a DataLoader object
③ loop over the DataLoader object and feed img, label into the model for training
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ....
So, as the step that directly feeds data into the model, the DataLoader is very important.
First, a brief introduction. DataLoader is an important data-loading interface in PyTorch, defined in dataloader.py. Almost every model trained with PyTorch goes through it (unless the user rewrites it...). Its purpose is to wrap a custom Dataset, according to the batch size, whether to shuffle, and so on, into batch-sized tensors used later for training.
The official description of DataLoader is: "Data loading is composed of a dataset and a sampler, and uses Python single- and multi-process iterators to process the data." For the concepts of iterator versus iterable, see the Python documentation; in implementation terms the difference is that an iterator has both __iter__ and __next__ methods, while an iterable only has __iter__.
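A tiny illustration of that distinction, with purely made-up class names; this mirrors how DataLoader (iterable, __iter__ only) hands off to _DataLoaderIter (iterator, __iter__ plus __next__):

```python
class CountDown:
    """Iterable: only __iter__, which hands back a fresh iterator."""
    def __init__(self, start):
        self.start = start

    def __iter__(self):
        return CountDownIterator(self.start)

class CountDownIterator:
    """Iterator: __iter__ returns self, __next__ produces values."""
    def __init__(self, current):
        self.current = current

    def __iter__(self):
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration
        value = self.current
        self.current -= 1
        return value

print(list(CountDown(3)))  # [3, 2, 1]
```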
1. DataLoader
First, the parameters of DataLoader(object):
dataset (Dataset)
: the dataset to load the data from
batch_size (int, optional)
: how many samples per batch
shuffle (bool, optional)
: whether to reshuffle the data at the start of every epoch
sampler (Sampler, optional)
: a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False
batch_sampler (Sampler, optional)
: similar to sampler, but returns the indices of a whole batch at a time; note that once this is specified, batch_size, shuffle, sampler and drop_last can no longer be specified (they are mutually exclusive)
num_workers (int, optional)
: how many subprocesses handle data loading; 0 means all data is loaded in the main process (default: 0)
collate_fn (callable, optional)
: the function that merges a list of samples into a mini-batch
pin_memory (bool, optional)
: if set to True, the data loader will copy tensors into CUDA pinned memory before returning them
drop_last (bool, optional)
: if set to True, this applies to the last, incomplete batch; for example, with batch_size=64 and only 100 samples in an epoch, the trailing 36 samples are simply thrown away during training...
If False (the default), execution continues as normal and the last batch is just a bit smaller.
timeout (numeric, optional)
: if positive, how long to wait when collecting a batch from a worker process; if the batch has not arrived within this time, it is given up. This value should always be >= 0. Default: 0
worker_init_fn (callable, optional)
: per-worker initialization function; if not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None). A usage sketch combining several of these parameters follows right after this list.
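A hedged sketch pulling several of these parameters together; the toy dataset, seed handling and batch size are illustrative choices, not taken from the original article:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # worker_init_fn: re-seed NumPy inside each worker so the workers do not
    # all produce identical "random" augmentations (see the docstring note below)
    np.random.seed(torch.initial_seed() % 2**32)

if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,          # reshuffle indices at every epoch
        num_workers=2,         # two worker subprocesses prepare batches
        pin_memory=True,       # copy batches into pinned (page-locked) memory for faster GPU transfer
        drop_last=True,        # 100 samples with batch_size 64 -> the trailing 36 are dropped
        worker_init_fn=seed_worker,
    )
    for imgs, labels in loader:
        print(imgs.shape)      # torch.Size([64, 3]); only the one full batch survives
```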
- First, when the DataLoader is initialized, it builds the index sampler over the dataset:
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): if ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): if not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by the main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the index list
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
Here RandomSampler and BatchSampler are what produce the index lists used to fetch each batch of data; the yield-batch mechanism already lives at this level! (A quick usage example follows after the two classes below.)
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): base sampler.
        batch_size (int): size of mini-batch.
        drop_last (bool): if ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
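A quick way to see these samplers in action, assuming a recent PyTorch where they are exposed under torch.utils.data; the data here is just a list of ten integers:

```python
from torch.utils.data import RandomSampler, SequentialSampler, BatchSampler

data = list(range(10))

print(list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

print(list(BatchSampler(RandomSampler(data), batch_size=3, drop_last=True)))
# e.g. [[7, 2, 9], [0, 5, 3], [8, 1, 4]] -- shuffled indices, last partial batch dropped
```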
- _DataLoaderIter(self) takes a DataLoader object as input; the num_workers=0 case is easy to follow, while num_workers != 0 brings in multiple worker processes to speed up data loading;
- without workers, batch = self.collate_fn([self.dataset[i] for i in indices]) turns the indices into actual data and returns (image, label); self.dataset[i] calls the Dataset object's __getitem__() method (a minimal Dataset is sketched right after this list);
- with workers, an index queue (index_queues) is created for every worker, and all workers share a single worker_result_queue data queue! The data itself is loaded inside the _worker_loop function;
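A minimal custom Dataset, just to make concrete what self.dataset[i] resolves to; the synthetic data and the class name (reusing MyDataset from the pipeline example above) are illustrative:

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, n=100):
        self.images = torch.randn(n, 3, 224, 224)
        self.labels = torch.randint(0, 10, (n,))

    def __getitem__(self, i):
        # This is what self.dataset[i] invokes inside the DataLoader / worker loop
        return self.images[i], self.labels[i]

    def __len__(self):
        return len(self.images)
```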
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
- All of this queue handling buffers prepared batches ahead of time, and that is what speeds up loading!