CELF++

import diffusion
from heapdict import heapdict

class Node(object):
    """Mutable bookkeeping record for one graph node in the CELF++ loop.

    All fields except ``node`` start at neutral values; the caller
    (``celfpp_with_networkx`` in this file) fills them in.
    """

    def __init__(self, node):
        # Identifier of the wrapped graph node.
        self.node = node
        # Gain-related numbers start at zero, position index as well.
        self.mg1, self.mg2, self.list_index = 0, 0, 0
        # Look-ahead reference node and evaluation round are not set yet.
        self.prev_best, self.flag = None, None

def celfpp_with_networkx(G_networkx, k, target_nodes, *, p=0.1, mc=1000,
                         spread_fn=None):
    """Select up to ``k`` seed nodes of ``G_networkx`` with CELF++.

    CELF++ is a lazy-greedy seed-selection scheme. Each node keeps:

    - ``mg1``: its marginal gain with respect to the current seed set ``S``.
    - ``mg2``: a look-ahead gain with respect to ``S`` plus ``prev_best``,
      the best node evaluated earlier in the same round.

    If ``prev_best`` becomes the next seed, ``mg2`` is already the exact new
    gain, and no spread re-estimate is needed.

    Args:
        G_networkx: Graph whose ``nodes`` attribute is iterable and supports
            ``G_networkx.nodes[node]['node_data'] = ...`` (e.g. a networkx
            graph). Each node's ``Node`` record is stored there as a side
            effect.
        k: Maximum number of seeds to select.
        target_nodes: Accepted for interface compatibility; currently unused.
        p: Forwarded unchanged to ``spread_fn`` (presumably the propagation
            probability of the diffusion model -- confirm).
        mc: Forwarded unchanged to ``spread_fn`` (presumably the number of
            Monte Carlo runs -- confirm). The defaults for ``p`` and ``mc``
            are placeholders: the original code referenced undefined
            globals ``p`` and ``mc``.
        spread_fn: Callable ``spread_fn(G_networkx, seeds, p, mc)`` that
            returns the estimated spread of ``seeds``. Defaults to
            ``ic_with_networkx``.

    Returns:
        The set of selected seed nodes. It has fewer than ``k`` elements
        if the graph has fewer than ``k`` nodes.
    """
    # NOTE(review): ``target_nodes`` is never used -- confirm whether the
    # spread should be restricted to these nodes.
    if spread_fn is None:
        # NOTE(review): ``ic_with_networkx`` is neither defined nor imported
        # in this file; presumably it is ``diffusion.ic_with_networkx`` --
        # confirm and qualify the name.
        spread_fn = ic_with_networkx

    def spread(seeds):
        """Estimate the spread of ``seeds`` with the configured model."""
        return spread_fn(G_networkx, seeds, p, mc)

    S = set()
    # Note that heapdict is a min-heap, hence priorities are negated gains.
    Q = heapdict()
    last_seed = None
    cur_best = None
    node_data_list = []

    # Round 0: compute each node's gain against the empty seed set, plus its
    # look-ahead gain against {cur_best}.
    for node in G_networkx.nodes:
        node_data = Node(node)
        node_data.mg1 = spread([node])
        node_data.prev_best = cur_best
        if cur_best is None:
            node_data.mg2 = node_data.mg1
        else:
            # Store the marginal gain sigma({u, cur_best}) - sigma({cur_best}).
            # The joint spread alone would inflate mg1 once mg2 is promoted.
            node_data.mg2 = spread([node, cur_best.node]) - cur_best.mg1
        node_data.flag = 0
        if cur_best is None or node_data.mg1 >= cur_best.mg1:
            cur_best = node_data
        G_networkx.nodes[node]['node_data'] = node_data
        node_data_list.append(node_data)
        node_data.list_index = len(node_data_list) - 1
        Q[node_data.list_index] = -node_data.mg1

    spread_S = None  # cached sigma(S); invalidated whenever S grows

    # Also stop when Q is exhausted, so k > number of nodes cannot make
    # peekitem() fail on an empty heap.
    while len(S) < k and Q:
        node_idx, _ = Q.peekitem()
        node_data = node_data_list[node_idx]

        if node_data.flag == len(S):
            # mg1 is exact for the current S and is the largest in Q.
            S.add(node_data.node)
            del Q[node_idx]
            last_seed = node_data
            # A new round begins. cur_best must be re-chosen among nodes
            # evaluated against the new S, never a node already in S.
            cur_best = None
            spread_S = None
            continue

        if (node_data.prev_best is last_seed
                and node_data.flag == len(S) - 1):
            # mg2 was computed last round against S_prev | {last_seed},
            # which equals the current S, so it is the exact gain now.
            node_data.mg1 = node_data.mg2
        else:
            if spread_S is None:
                spread_S = spread(S)
            node_data.mg1 = spread(S | {node_data.node}) - spread_S
            node_data.prev_best = cur_best
            if cur_best is None:
                node_data.mg2 = node_data.mg1
            else:
                # Use temporary sets; mutating S here could drop real seeds.
                base = S | {cur_best.node}
                node_data.mg2 = spread(base | {node_data.node}) - spread(base)

        if cur_best is None or node_data.mg1 > cur_best.mg1:
            cur_best = node_data

        node_data.flag = len(S)
        Q[node_idx] = -node_data.mg1
    return S
posted @ 2024-01-04 13:39  X1OO  阅读(5)  评论(0)    收藏  举报