# update_all = send + recv: send messages along all edges, then let every node receive them




'''
update_all: send messages along all edges, then update all nodes.
DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')

message_func    -- message function, applied on the edges
reduce_func     -- reduce function, applied on the nodes
apply_node_func -- apply function, applied on the nodes
'''


'''
DGLGraph.send(edges='__ALL__', message_func='default')
edges -- the edges to send messages along, given as one of:
    int                                      -- one edge, by edge id
    pair of int                              -- one edge, by its endpoints
    int iterable / tensor                    -- multiple edges, by edge id
    pair of int iterables / pair of tensors  -- multiple edges, by their endpoints

Computes messages on the given edges with message_func; the messages can later
be fetched from each destination node's mailbox.
'''


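'''
DGLGraph.recv(v='__ALL__', reduce_func='default', apply_node_func='default', inplace=False)

Lets the nodes v consume the messages waiting in their mailboxes: reduce_func
combines each node's incoming messages into new node features, and
apply_node_func (if given) is then applied to the updated nodes.
'''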

import warnings
warnings.filterwarnings("ignore")
import torch as th
import dgl
# Toy graph: 3 nodes, each carrying a one-element feature "x" (5, 6, 7).
g = dgl.DGLGraph()
g.add_nodes(3)
g.ndata["x"] = th.tensor([[5.0], [6.0], [7.0]])

# Edges receive ids in insertion order. This call passes plain lists and
# creates edge 0 (0 -> 1) and edge 1 (1 -> 2).
g.add_edges([0, 1], [1, 2])

# This call passes tensors instead and creates edge 2 (0 -> 2).
src = th.tensor([0])
dst = th.tensor([2])
g.add_edges(src, dst)

print("ndata", g.ndata["x"])


def send_source(edges):
    """Message function: forward each edge's source-node feature "x" as "m".

    DGL calls this with a batch of edges. For inspection, it prints the
    source-node and destination-node features of every edge in the batch.
    Both have one row per edge: the sample output below shows
    torch.Size([3, 1]) when all three edges are sent.

    Returns a dict mapping the message name "m" to the source features.
    These messages end up in each destination node's mailbox.
    """
    from_feat = edges.src["x"]  # feature of each edge's source node
    to_feat = edges.dst["x"]    # feature of each edge's destination node
    print("src", from_feat.shape, from_feat)
    print("dst", to_feat.shape, to_feat)
    return {"m": from_feat}

# Register send_source as the graph's default message function, so a later
# send() call that passes no message_func uses it.
# NOTE(review): register_message_func / send / recv belong to the DGL 0.3/0.4
# API and appear to have been removed in DGL 0.5+ (send_and_recv / update_all
# replace them). Confirm against the installed DGL version.
g.register_message_func(send_source)

'''
ndata tensor([[5.],
        [6.],
        [7.]])
src torch.Size([3, 1]) tensor([[5.],
        [6.],
        [5.]])
dst torch.Size([3, 1]) tensor([[6.],
        [7.],
        [7.]])

'''




def simple_reduce(nodes):
    """Reduce function: set each node's "x" to the sum of its incoming messages.

    ``nodes.data["x"]`` holds the current features of the nodes in this call.
    ``nodes.mailbox["m"]`` stacks their incoming "m" messages along dim 1.

    In the sample output below, the mailbox has shape (1, 1, 1) for node 1
    (one in-edge) and (1, 2, 1) for node 2 (two in-edges), and DGL invoked
    this function separately for each. Presumably DGL batches nodes by
    in-degree (TODO confirm against the DGL docs).

    Summing over dim 1 aggregates the messages per node. The result becomes
    the new "x" for those nodes.
    """
    print("data_nodes", nodes.data["x"])
    inbox = nodes.mailbox["m"]
    print("mailbox", inbox.shape, inbox)
    aggregated = inbox.sum(1)  # collapse the incoming-message dimension
    print("sum", aggregated)
    return {"x": aggregated}

# Register simple_reduce as the default reduce function used by recv().
g.register_reduce_func(simple_reduce)

# Send along every edge, then let every node receive. g.edges() is presumably
# the (src, dst) endpoint pair, i.e. the "pair of tensor" form of `edges`.
# Together these two calls mirror update_all(), which sends along all edges
# and updates all nodes.
g.send(g.edges())
g.recv(g.nodes())
# Per the sample output below, node 0 (no incoming edges) ends up with
# x == 0 rather than keeping its original value 5.
print("ndata",g.ndata["x"])


'''

data_nodes tensor([[6.]])
mailbox torch.Size([1, 1, 1]) tensor([[[5.]]])
sum tensor([[5.]])
data_nodes tensor([[7.]])
mailbox torch.Size([1, 2, 1]) tensor([[[6.],
         [5.]]])
sum tensor([[11.]])
ndata tensor([[ 0.],
        [ 5.],
        [11.]])

'''

  

# Posted on 2019-09-25 17:41 by happygril3 (views: 486, comments: 0)