数据

1 #初始值为0,dtype为float64
2 state = theano.shared(np.array(0,dtype=np.float64))    #定义shared变量state
3 inc = T.scalar('inc',dtype=state.dtype)    #统一格式为state的dtype
4 accumulator = theano.function([inc], state, updates=[(state,state+inc)])
5 #updates作为一个参数去更新state
6 #state = state+inc

 

#to get variable value

1 print(state.get_value())    #state=0.0
2 accumulator(1)             
3 print(state.get_value())    #state=1.0
4 accumulator(10)        
5 print(state.get_value())    #state=10.0

 

#to set variable value

1 state.set_value(-1)    #改变了state(原为10.0)
2 accumulator(3)        
3 print(state.get_value())    #satte=2.0

 

#temporaily replace shared variable with another value
#临时方程,不更新state
1 tmp_func = state*2 +inc    #临时方程
2 a = T.scalar(dtype=state.dtype)    #和state一样的dtype
3 skip_shared = theano.function([inc,a],tmp_func,gives=[(state,a)])
4 #gives让state是a,即不改变state但临时方程中使用到state的值时是a

 

执行

1 print(skip_shared(2,3))    #inc=2,a=3    结果为8
2 print(state.get_value())    #state不变=2

 

 
posted on 2022-08-02 15:17  Jolyne123  阅读(46)  评论(0)    收藏  举报