pytorch入门与实践-2.2-CIFAR10分类网络

1--数据载入

|----流程: DataSet->DataLoader->调用DataLoader

|----DataLoader迭代器读不到数据,无报错,一直卡住的显现:

  DataLoader的num_worker设置为0 解决,这个多线程读取数据不知道为什么会出现这个bug?

2--定义网络

|----与之前的一致

 

3--损失函数和优化器

|----损失函数: 交叉熵

|----优化器:SGD (?)

 

4--网络训练

|----读入数据

|----输入网络

|----获得结果,计算loss

|----清空梯度

|----反向传播

|--训练结束

 

5--使用cuda加速

|----先将网络载入显卡 net.cuda(),再将运算的数据也载如cuda, x.cuda()

|----注意处:但是网络传播时输入的还应该是 Variable对象,只是底层利用了cuda对象实现

|----问题:?虽然占用了部分GPU但是没感觉有啥速度上的变化? batch增大后会卡死(batch==64时巨卡)?

|---原因分析: 也许是因为每个batch单独载入GPU的问题? 如何在开始时一次性载入 ? 如何不重复载入?

 

6--第二章总结

|----pytorch 三元素 Tensor,Variable,NN

posted @ 2019-02-27 14:08  ChenXianRen  阅读(273)  评论(0编辑  收藏  举报