资源描述:
《手写数字识别(mxnet官网例子).doc》由会员上传分享,免费在线阅读,更多相关内容在行业资料-天天文库。
1、手写数字识别简介:通过MNIST数据集建立一个手写数字分类器。(MNIST对于手写数据分类任务是一个广泛使用的数据集)。1.前提:mxnet0.10及以上、python、jupyternotebook(有时间可以jupyternotebook的用法,如:PPT的制作)pipinstallrequestsjupyter——python下jupyternotebook的安装2.加载数据集:importmxnetasmxmnist=mx.test_utils.get_mnist()此时MXNET数据集已完全加载到内存中(注
2、:此法对于大型数据集不适用)考虑要素:快速高效地从源直接流数据+输入样本的顺序图像通常用4维数组来表示:(batch_size,num_channels,width,height)对于MNIST数据集,因为是28*28灰度图像,所以只有1个颜色通道,width=28,height=28,本例中batch=100(批处理100),即输入形状是(batch_size,1,28,28)数据迭代器通过随机的调整输入来解决连续feed相同样本的问题。测试数据的顺序无关紧要。batch_size=100train_iter=mx
3、.io.NDArrayIter(mnist['train_data'],mnist['train_label'],batch_size,shuffle=True)val_iter=mx.io.NDArrayIter(mnist['test_data'],mnist['test_label'],batch_size)——初始化MNIST数据集的数据迭代器(2个:训练数据+测试数据)3.训练+预测:(2种方法)(CNN优于MLP)1)传统深度神经网络结构——MLP(多层神经网络)MLP——MXNET的符号接口为输入的数据
4、创建一个占位符变量data=mx.sym.var('data')data=mx.sym.flatten(data=data)——将数据从4维变成2维(batch_size,num_channel*width*height)fc1=mx.sym.FullyConnected(data=data,num_hidden=128)act1=mx.sym.Activation(data=fc1,act_type="relu")——第一个全连接层及相应的激活函数fc2=mx.sym.FullyConnected(data=act
5、1,num_hidden=64)act2=mx.sym.Activation(data=fc2,act_type="relu")——第二个全连接层及相应的激活函数(声明2个全连接层,每层有128个和64个神经元)fc3=mx.sym.FullyConnected(data=act2,num_hidden=10)——声明大小10的最终完全连接层mlp=mx.sym.SoftmaxOutput(data=fc3,name='softmax')——softmax的交叉熵损失MNIST的MLP网络结构以上,已完成了数据迭代器
6、和神经网络的申明,下面可以进行训练。超参数:处理大小、学习速率importlogginglogging.getLogger().setLevel(logging.DEBUG)——记录到标准输出mlp_model=mx.mod.Module(symbol=mlp,context=mx.cpu())——在CPU上创建一个可训练的模块mlp_model.fit(train_iter——训练数据eval_data=val_iter,——验证数据optimizer='sgd',——使用SGD训练optimizer_params
7、={'learning_rate':0.1},——使用固定的学习速率eval_metric='acc',——训练过程中报告准确性batch_end_callback=mx.callback.Speedometer(batch_size,100),——每批次100数据输出的进展num_epoch=10)——训练至多通过10个数据预测:test_iter=mx.io.NDArrayIter(mnist['test_data'],None,batch_size)prob=mlp_model.predict(test_ite
8、r)assertprob.shape==(10000,10)——计算每一个测试图像可能的预测得分(prob[i][j]第i个测试图像包含j输出类)test_iter=mx.io.NDArrayIter(mnist['test_data'],mnist['test_label'],batch_size)——预测精度的方法acc=mx.metric