数据
1 #初始值为0,dtype为float64 2 state = theano.shared(np.array(0,dtype=np.float64)) #定义shared变量state 3 inc = T.scalar('inc',dtype=state.dtype) #统一格式为state的dtype 4 accumulator = theano.function([inc], state, updates=[(state,state+inc)]) 5 #updates作为一个参数去更新state 6 #state = state+inc
#to get variable value
1 print(state.get_value()) #state=0.0 2 accumulator(1) 3 print(state.get_value()) #state=1.0 4 accumulator(10) 5 print(state.get_value()) #state=10.0
#to set variable value
1 state.set_value(-1) #改变了state(原为10.0) 2 accumulator(3) 3 print(state.get_value()) #satte=2.0
#temporaily replace shared variable with another value
#临时方程,不更新state
1 tmp_func = state*2 +inc #临时方程 2 a = T.scalar(dtype=state.dtype) #和state一样的dtype 3 skip_shared = theano.function([inc,a],tmp_func,gives=[(state,a)]) 4 #gives让state是a,即不改变state但临时方程中使用到state的值时是a
执行
1 print(skip_shared(2,3)) #inc=2,a=3 结果为8 2 print(state.get_value()) #state不变=2
浙公网安备 33010602011771号