新闻动态

pytorch 实现在测试的时候启用dropout

发布日期:2022-04-04 19:26 | 文章来源:站长之家

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
 if type(m) == nn.Dropout:
  m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
 def __init__(self):
  super(SimpleNet, self).__init__()
  self.fc = nn.Linear(8, 8)
  self.dropout = nn.Dropout(0.5)
 def forward(self, x):
  x = self.fc(x)
  x = self.dropout(x)
  return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
 if type(m) == nn.Dropout:
  m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

二.搭建神经网络

三.训练

四.对比测试结果

注意:测试过程中,一定要注意模式切换

以上为个人经验,希望能给大家一个参考,也希望大家多多支持本站。

海外服务器租用

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部