新闻动态

pytorch实现手写数字图片识别

发布日期:2022-06-01 14:29 | 文章来源:脚本之家

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot
# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
 torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize(
  (0.1307,), (0.3081,)
 )
])),
 batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
 torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize(
  (0.1307,), (0.3081,)
 )
])),
 batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")
class Net(nn.Module):
 def __init__(self):
  super(Net, self).__init__()
  self.fc1 = nn.Linear(28*28, 256)
  self.fc2 = nn.Linear(256, 64)
  self.fc3 = nn.Linear(64, 10)
 def forward(self, x):
  # x: [b, 1, 28, 28]
  # h1 = relu(xw1 + b1)
  x = F.relu(self.fc1(x))
  # h2 = relu(h1w2 + b2)
  x = F.relu(self.fc2(x))
  # h3 = h2w3 + b3
  x = self.fc3(x)
  return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []
for epoch in range(3):
 for batch_idx, (x, y) in enumerate(train_loader):
  #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
  #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
  x = x.view(x.size(0), 28*28)
  #  [b, 10]
  out = net(x)
  y_onehot = one_hot(y)
  # loss = mse(out, y_onehot)
  loss = F.mse_loss(out, y_onehot)
  optimizer.zero_grad()
  loss.backward()
  # w' = w - lr*grad
  optimizer.step()
  train_loss.append(loss.item())
  if batch_idx % 10 == 0:
print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
 # we get optimal [w1, b1, w2, b2, w3, b3]

total_correct = 0
for x,y in test_loader:
 x = x.view(x.size(0), 28*28)
 out = net(x)
 # out: [b, 10]
 pred = out.argmax(dim=1)
 correct = pred.eq(y).sum().float().item()
 total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from matplotlib import pyplot as plt

def plot_curve(data):
 fig = plt.figure()
 plt.plot(range(len(data)), data, color='blue')
 plt.legend(['value'], loc='upper right')
 plt.xlabel('step')
 plt.ylabel('value')
 plt.show()

def plot_image(img, label, name):
 fig = plt.figure()
 for i in range(6):
  plt.subplot(2, 3, i + 1)
  plt.tight_layout()
  plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
  plt.title("{}: {}".format(name, label[i].item()))
  plt.xticks([])
  plt.yticks([])
 plt.show()

def one_hot(label, depth=10):
 out = torch.zeros(label.size(0), depth)
 idx = torch.LongTensor(label).view(-1, 1)
 out.scatter_(dim=1, index=idx, value=1)
 return out

打印出损失下降的曲线图:

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持本站。

海外稳定服务器

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部