新闻动态

Pytorch dataloader在加载最后一个batch时卡死的解决

发布日期:2022-04-02 10:53 | 文章来源:源码中国

问题:

自己写了个dataloader,为了部署方便,用OpenCV的接口进行数据读取,而没有用PIL,代码大致如下:

 def __getitem__(self, idx):
  sample = self.samples[idx]
 
  img = cv2.imread(sample[0])
  img = cv2.resize(img, tuple(self.input_size))
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  # if not self.val and random.randint(1, 10) < 3:
  #  img = self.img_aug(img)
  img = Image.fromarray(img) 
  img = self.transforms(img)  
  ...

结果在训练过程中,在第1个epoch的最后一个batch时,程序卡死。

解决方案:

可能是因为OpenCV与Pytorch互锁的问题,关闭OpenCV的多线程,问题解决。

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

补充:pytorch 中一个batch的训练过程

# 一般情况下
optimizer.zero_grad() # 梯度清零
preds = model(inputs) # inference,前向传播求出预测值
loss = criterion(preds, targets)  # 计算loss
loss.backward() # 反向传播求解梯度
optimizer.step()# 更新权重,更新网络权重参数

此外,反向传播前,如果不进行梯度清零,则可以实现梯度累加,从而一定程度上解决显存受限的问题。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持本站。

美国快速服务器

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部