# apply_nodes_func — demo of DGLGraph.apply_nodes with a user-defined node function

import torch as th
import dgl

# Build an empty graph, then give it three nodes (ids 0, 1, 2).
g = dgl.DGLGraph()
g.add_nodes(3)

# Node features are stored one row per node: the leading dimension (3) has to
# equal the number of nodes, while the trailing dimension (4) is the feature width.
g.ndata["x"] = th.ones(3, 4)

print("nodes", g.nodes())
print("ndata", g.ndata["x"])

#increment the node feature by 1
def increment_feature(nodes):
    return {"x":nodes.data["x"]+1}

g.apply_nodes(func=increment_feature,v=[0,2])#apply func to nodes 0 and 2
print("ndata",g.ndata["x"])


'''
ndata tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
ndata tensor([[2., 2., 2., 2.],
        [1., 1., 1., 1.],
        [2., 2., 2., 2.]])


'''

  

# posted on 2019-09-25 19:32 by happygril3 — 阅读(302) [views: 302]  评论(0) [comments: 0]  收藏 [bookmark]  举报 [report]

# 导航 [navigation]