新闻动态

PyTorch零基础入门之逻辑斯蒂回归

发布日期:2021-12-23 13:51 | 文章来源:gibhub

学习总结

(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_datay_data这里也是矩阵的形式)。

一、sigmoid函数

logistic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:

二、和Linear的区别

逻辑斯蒂和线性模型的unit区别如下图:

sigmoid函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。

如下图右方表格所示,当 y ^ \hat{y} y^​越接近y时则BCE Loss值越小。

三、逻辑斯蒂回归(分类)PyTorch实现

# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021
@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt  
import torch.nn.functional as F
import numpy as np
# 准备数据
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

losslst = []
class LogisticRegressionModel(nn.Module):
 def __init__(self):
  super(LogisticRegressionModel, self).__init__()
  self.linear = torch.nn.Linear(1, 1)
  
 def forward(self, x):
 	# 和线性模型的网络的唯一区别在这句,多了F.sigmoid
  y_predict = F.sigmoid(self.linear(x))
  return y_predict
 
model = LogisticRegressionModel()
# 使用交叉熵作损失函数
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), 
lr = 0.01)
# 训练
for epoch in range(1000):
 y_predict = model(x_data)
 loss = criterion(y_predict, y_data)
 # 打印loss对象会自动调用__str__
 print(epoch, loss.item())
 losslst.append(loss.item())
 # 梯度清零后反向传播
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
# 画图
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()

# test
# 每周学习的时间,200个点
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 画 probability of pass = 0.5的红色横线
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

可以看出处于通过和不通过的分界线是Hours=2.5。

Reference

pytorch官方文档

到此这篇关于PyTorch零基础入门之逻辑斯蒂回归的文章就介绍到这了,更多相关PyTorch 逻辑斯蒂回归内容请搜索本站以前的文章或继续浏览下面的相关文章希望大家以后多多支持本站!

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部