//目录

Gluon 参数的保存与读取

NDArray 的保存与读取：nd.save / nd.load

from mxnet import nd
from mxnet.gluon import nn

# Two small arrays used to demonstrate NDArray serialization.
x = nd.ones(3)   # length-3 array of ones
y = nd.zeros(4)  # length-4 array of zeros

# The save/load demos below stay commented out. Uncommenting them writes the
# files 'x', 'xy' and 'mydict' (relative paths, i.e. the working directory).
#
# 1) Round-trip a single NDArray:
# nd.save('x', x)
# x2 = nd.load('x')
# print(x2)
#
# 2) Round-trip a list of NDArrays; the loaded result unpacks into the same
#    number of arrays that were saved:
# print([x, y])
# nd.save('xy', [x, y])
# x2, y2 = nd.load('xy')
# print(x2, y2)
#
# 3) Round-trip a dict of NDArrays keyed by name:
mydict = dict(x=x, y=y)
# nd.save('mydict', mydict)
# mydict2 = nd.load('mydict')
# print(mydict2)  # presumably a dict with the same keys -- confirm

Gluon 模型参数的保存与读取：save_parameters / load_parameters

from mxnet import nd
from mxnet.gluon import nn

class MLP(nn.Block):
    """Two-layer perceptron: a ReLU hidden Dense layer, then a linear Dense output.

    Args:
        hidden_units: width of the hidden layer. Defaults to 256, the value
            previously hard-coded here.
        num_outputs: width of the output layer. Defaults to 10, the value
            previously hard-coded here.
        **kwargs: forwarded unchanged to ``nn.Block.__init__``.
    """

    def __init__(self, hidden_units=256, num_outputs=10, **kwargs):
        super(MLP, self).__init__(**kwargs)
        # No in_units is passed, so the input width is presumably left to
        # Gluon's deferred shape inference (this file feeds 20-feature rows).
        # NOTE(review): Gluon's save_parameters/load_parameters appear to key
        # weights by these attribute names (e.g. 'hidden.weight'). Renaming
        # them would likely break loading previously saved files -- confirm.
        self.hidden = nn.Dense(hidden_units, activation='relu')
        self.output = nn.Dense(num_outputs)

    def forward(self, x):
        """Return ``output(hidden(x))`` for the input batch ``x``."""
        return self.output(self.hidden(x))

# --- Run 1 (executed once, now commented out) ---
# Build a fresh MLP and initialize it. Run it on a random (2, 20) batch.
# Then persist the input 'X', the output 'Y' and the network's parameters.
# net = MLP()
# net.initialize()
# X = nd.random.uniform(shape=(2, 20))
# Y = net(X)
# print(Y)
# nd.save('X', X)
# nd.save('Y', Y)

filename = 'mlp.params'  # parameter file written by Run 1
# net.save_parameters(filename)

# --- Run 2 ---
# Rebuild the same architecture and restore the saved weights into it.
# net2 is never initialize()d; its weights come from load_parameters alone.
# This run requires the files 'mlp.params', 'X' and 'Y' from Run 1.
net2 = MLP()
net2.load_parameters(filename)

# nd.load yields a list of NDArrays, so [0] selects the single saved array.
X, Y = nd.load('X'), nd.load('Y')
# print(X[0])
Y2 = net2(X[0])

# Element-wise comparison. All ones means the restored network reproduces
# the saved output.
print(Y[0] == Y2)

 

posted @ 2019-01-08 10:16  小草的大树梦  阅读(414)  评论(0)  编辑  收藏  举报