Implementing Softmax Classification on Fashion-MNIST
Published: 2019-06-11


 

 

softmax
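The code below follows the usual from-scratch recipe: flatten each 28x28 image into a 784-dimensional vector x, compute the logits o = xW + b, map them to probabilities with softmax, and train by minimizing cross-entropy. For reference, these are my restatements of the standard formulas the code implements (not text from the original post):

    softmax(o)_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}

    \ell(y, \hat{y}) = -\log \hat{y}_y

That is, the loss only looks at the predicted probability assigned to the true class y, which is exactly the value that nd.pick(y_hat, y) extracts in the code.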

 

 

#!/usr/bin/env python
# coding: utf-8

# In[1]:
get_ipython().run_line_magic('matplotlib', 'inline')
import gluonbook as gb
from mxnet import autograd, nd

# In[2]:
batch_size = 256
train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)

# In[3]:
num_inputs = 784    # each 28x28 image is flattened into a 784-dim vector
num_outputs = 10    # 10 clothing categories
W = nd.random.normal(scale=0.01, shape=(num_inputs, num_outputs))
b = nd.zeros(num_outputs)

# In[4]:
W.attach_grad()
b.attach_grad()

# The softmax operation
# In[5]:
X = nd.array([[1, 2, 3], [4, 5, 6]])
X.sum(axis=0, keepdims=True)

# In[6]:
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(axis=1, keepdims=True)
    return X_exp / partition    # broadcasting: each row is divided by its own sum

# For example
# In[7]:
X = nd.random.normal(shape=(2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(axis=1)

# Define the model
# In[8]:
def net(X):
    return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)

# Define the loss function
# In[9]:
y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = nd.array([0, 2])
nd.pick(y_hat, y)

# Cross-entropy loss
# In[10]:
def cross_entropy(y_hat, y):
    return -nd.pick(y_hat, y).log()

# Classification accuracy
# In[11]:
def accuracy(y_hat, y):
    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()

# In[12]:
accuracy(y_hat, y)

# Evaluate the accuracy of net on data_iter
# In[13]:
def evaluate_accuracy(data_iter, net):
    acc = 0
    for X, y in data_iter:
        acc += accuracy(net(X), y)
    return acc / len(data_iter)

# In[14]:
evaluate_accuracy(test_iter, net)

# Train the model
# In[15]:
num_epochs, lr = 5, 0.1

def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
          params=None, lr=None, trainer=None):
    for epoch in range(num_epochs):
        train_l_sum = 0
        train_acc_sum = 0
        for X, y in train_iter:     # iterate over mini-batches
            with autograd.record():
                y_hat = net(X)      # forward pass: softmax(XW + b)
                l = loss(y_hat, y)  # cross-entropy loss
            l.backward()            # back-propagate the cross-entropy loss
            gb.sgd(params, lr, batch_size)   # update the parameters W, b
            train_l_sum += l.mean().asscalar()
            train_acc_sum += accuracy(y_hat, y)
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / len(train_iter),
                 train_acc_sum / len(train_iter), test_acc))

train(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

# Prediction
# In[16]:
for X, y in test_iter:
    break
true_labels = gb.get_fashion_mnist_labels(y.asnumpy())
pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
gb.show_fashion_mnist(X[0:9], titles[0:9])
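One caveat worth noting: the softmax above exponentiates the logits directly, so very large logits can overflow to inf. Below is a minimal sketch of the usual fix (my addition, not part of the original post); it relies on the same NDArray broadcasting behaviour that the post's own softmax already uses.

# A numerically stable softmax (sketch, not from the original post):
# subtracting each row's maximum before exponentiating leaves the result
# unchanged but keeps exp() from overflowing on large logits.
def stable_softmax(X):
    X_shifted = X - X.max(axis=1, keepdims=True)
    X_exp = X_shifted.exp()
    return X_exp / X_exp.sum(axis=1, keepdims=True)

# Quick sanity check: each row should still sum to 1.
X = nd.random.normal(shape=(2, 5))
stable_softmax(X).sum(axis=1)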
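The training loop delegates the parameter update to gb.sgd(params, lr, batch_size). I have not reproduced gluonbook's source here, but the mini-batch SGD step it performs looks, in spirit, like the hypothetical sketch below.

# Hypothetical re-implementation of the mini-batch SGD step, for illustration only:
# each parameter moves against its gradient, scaled by lr and averaged over the batch.
def sgd_sketch(params, lr, batch_size):
    for param in params:
        param[:] = param - lr * param.grad / batch_size

Because l.backward() sums the per-example gradients over the mini-batch, dividing by batch_size turns the step into an average over the batch.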


Reposted from: https://www.cnblogs.com/TreeDream/p/10002918.html
