More efficient local push: q.pop(0) instead of q.pop()? #8

Open
baojian opened this issue Apr 13, 2024 · 1 comment

baojian commented Apr 13, 2024

Hi Dr. Gasteiger,

I'd like to ask about the implementation of _calc_ppr_node(), where each active node is taken from the list q with unode = q.pop(). Note that pop() removes the last element, so q is used as a stack (LIFO). In my test on a toy graph, this is less efficient than pop(0), which uses the list as a FIFO queue: with pop(0), far fewer neighbor updates are needed for the same precision parameter. Here is my testing code:

import numpy as np
import networkx as nx
import numba
from numba import njit


@njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32})
def _calc_ppr_node(inode, indptr, indices, deg, alpha, epsilon):
    alpha_eps = alpha * epsilon
    f32_0 = numba.float32(0)
    p = {inode: f32_0}
    r = {inode: alpha}
    q = [inode]
    total_opers = 0.  # neighbor updates so far (sum of degrees of popped nodes)
    while len(q) > 0:
        unode = q.pop()
        res = r[unode] if unode in r else f32_0
        if unode in p:
            p[unode] += res
        else:
            p[unode] = res
        r[unode] = f32_0
        total_opers += deg[unode]
        for vnode in indices[indptr[unode]:indptr[unode + 1]]:
            _val = (1 - alpha) * res / deg[unode]
            if vnode in r:
                r[vnode] += _val
            else:
                r[vnode] = _val
            res_vnode = r[vnode] if vnode in r else f32_0
            if res_vnode >= alpha_eps * deg[vnode]:
                if vnode not in q:
                    q.append(vnode)
    return list(p.keys()), list(p.values()), total_opers


@njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32})
def _calc_ppr_node_pop_first(inode, indptr, indices, deg, alpha, epsilon):
    alpha_eps = alpha * epsilon
    f32_0 = numba.float32(0)
    p = {inode: f32_0}
    r = {inode: alpha}
    q = [inode]
    total_opers = 0.  # neighbor updates so far (sum of degrees of popped nodes)
    while len(q) > 0:
        unode = q.pop(0)  # the only change: pop from the front (FIFO queue)
        res = r[unode] if unode in r else f32_0
        if unode in p:
            p[unode] += res
        else:
            p[unode] = res
        r[unode] = f32_0
        total_opers += deg[unode]
        for vnode in indices[indptr[unode]:indptr[unode + 1]]:
            _val = (1 - alpha) * res / deg[unode]
            if vnode in r:
                r[vnode] += _val
            else:
                r[vnode] = _val
            res_vnode = r[vnode] if vnode in r else f32_0
            if res_vnode >= alpha_eps * deg[vnode]:
                if vnode not in q:
                    q.append(vnode)
    return list(p.keys()), list(p.values()), total_opers


@njit(cache=True)
def calc_ppr(indptr, indices, deg, alpha, epsilon, nodes, model='pop-last'):
    js = []
    vals = []
    opers = []
    for i, node in enumerate(nodes):
        if model == 'pop-last':
            j, val, oper = _calc_ppr_node(node, indptr, indices, deg, alpha, epsilon)
        else:
            j, val, oper = _calc_ppr_node_pop_first(node, indptr, indices, deg, alpha, epsilon)
        js.append(j)
        vals.append(val)
        opers.append(oper)
    return js, vals, opers


def toy_example_graph():
    """
    From GraphSAGE: https://arxiv.org/pdf/1706.02216.pdf
    """
    n = 15
    adj_list = {0: [1, 12], 1: [0, 2, 5, 9], 2: [1], 3: [5, 6], 4: [5],
                5: [1, 3, 4, 6, 9], 6: [3, 5], 7: [8], 8: [7, 9], 9: [1, 5, 8, 11, 12],
                10: [11], 11: [9, 10], 12: [0, 9, 13, 14], 13: [12, 14], 14: [12, 13]}
    graph = nx.Graph()
    for u in range(n):
        for v in adj_list[u]:
            graph.add_edge(u, v)
    csr_graph = nx.to_scipy_sparse_array(graph, nodelist=range(n))
    degree = csr_graph.indptr[1:] - csr_graph.indptr[:n]
    indices = csr_graph.indices
    indptr = csr_graph.indptr
    n = len(degree)
    return n, indptr, indices, degree


def test_toy_example_graph():
    n, indptr, indices, degree = toy_example_graph()
    s_node = 0  # source node
    alpha = 0.1  # teleport (restart) probability
    eps = 1e-4  # precision parameter
    js, vals, opers = calc_ppr(indptr, indices, degree, alpha, eps, [s_node], model='pop-last')
    print(np.sum(vals[0]), opers[0])
    js, vals, opers = calc_ppr(indptr, indices, degree, alpha, eps, [s_node], model='pop-first')
    print(np.sum(vals[0]), opers[0])


test_toy_example_graph()

The above prints the sum of the PPR values and the operation count (total_opers), first for pop() and then for pop(0):

0.9981211414560676 9704.0
0.9982607923448086 1062.0
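As a side note on the two sums: assuming the push rule in the code above (p[u] += res, r[u] = 0, and (1 - alpha) * res / deg[u] added to each of the deg[u] entries in u's CSR row), a short calculation shows why both are slightly below 1 and how far below they can be. Here α is alpha, ε is eps, and |E| is the number of edges; the code stores residuals as float32, so the equality only holds up to rounding.

$$
\begin{aligned}
\text{push of } u:\quad &\Delta \textstyle\sum_v p_v = r_u, \qquad
\Delta \textstyle\sum_v r_v = -r_u + \deg(u)\,\frac{(1-\alpha)\,r_u}{\deg(u)} = -\alpha\, r_u \\
\Rightarrow\quad &\textstyle\sum_v p_v + \frac{1}{\alpha}\sum_v r_v = 0 + \frac{\alpha}{\alpha} = 1 \quad \text{(invariant)} \\
\text{at termination, } r_v < \alpha\,\varepsilon\,\deg(v):\quad &1 - \textstyle\sum_v p_v = \frac{1}{\alpha}\sum_v r_v < \varepsilon \sum_v \deg(v) = 2\varepsilon\,|E| = 10^{-4}\cdot 36 = 0.0036
\end{aligned}
$$

The toy graph has 18 edges, so the degrees sum to 36. Both printed gaps (1 - 0.99812 ≈ 0.0019 and 1 - 0.99826 ≈ 0.0017) are within this bound, so the two pop orders meet the same accuracy guarantee and differ only in how much work they need.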

gasteigerjo (Contributor) commented Apr 15, 2024

Hmm, that's quite interesting. I didn't expect the difference to be this large!

However, Python lists are not designed for removing items from the front: pop(0) shifts all remaining elements, so each call costs O(len(q)) and adds significant runtime overhead. That's why there is collections.deque, whose popleft() is O(1). Unfortunately, it seems like Numba doesn't support deque yet. Also, your operation count doesn't distinguish between loop iterations that enter the if-condition and those that don't. Overall, pop(0) might not be as good as it seems. To really check this, we need to measure the runtime of the two variants with timeit.
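To follow up on the deque and timeit suggestions, here is a minimal pure-Python sketch of the same push loop. The names ppr_push, to_csr, TOY_ADJ and the stats keys are hypothetical. It counts pops, neighbor updates (what total_opers measures) and enqueues separately, and tracks queue membership in a set instead of scanning the list. It is plain Python, not Numba, so its absolute timings are not comparable to the njit code above; it only illustrates how the two pop orders could be compared.

```python
import timeit
from collections import deque

# Toy graph from the test code above (GraphSAGE example), as an adjacency list.
TOY_ADJ = {0: [1, 12], 1: [0, 2, 5, 9], 2: [1], 3: [5, 6], 4: [5],
           5: [1, 3, 4, 6, 9], 6: [3, 5], 7: [8], 8: [7, 9], 9: [1, 5, 8, 11, 12],
           10: [11], 11: [9, 10], 12: [0, 9, 13, 14], 13: [12, 14], 14: [12, 13]}


def to_csr(adj):
    """Build plain-list CSR arrays (indptr, indices, deg) from an adjacency dict."""
    indptr, indices = [0], []
    for u in range(len(adj)):
        indices.extend(sorted(adj[u]))
        indptr.append(len(indices))
    deg = [indptr[u + 1] - indptr[u] for u in range(len(adj))]
    return indptr, indices, deg


def ppr_push(source, indptr, indices, deg, alpha, epsilon, order="fifo"):
    """Push-based approximate PPR for one source node, in plain Python.

    order="lifo" pops from the right end like q.pop(); order="fifo" pops from
    the left end like q.pop(0), but deque.popleft() is O(1).
    """
    alpha_eps = alpha * epsilon
    p, r = {}, {source: alpha}
    q = deque([source])
    in_queue = {source}  # O(1) membership test instead of scanning q
    stats = {"pops": 0, "edge_updates": 0, "enqueues": 0}
    while q:
        u = q.popleft() if order == "fifo" else q.pop()
        in_queue.discard(u)
        res = r.get(u, 0.0)
        p[u] = p.get(u, 0.0) + res
        r[u] = 0.0
        stats["pops"] += 1
        share = (1 - alpha) * res / deg[u]
        for v in indices[indptr[u]:indptr[u + 1]]:
            v = int(v)  # also works if indices is a NumPy array
            r[v] = r.get(v, 0.0) + share
            stats["edge_updates"] += 1  # what total_opers counts
            if r[v] >= alpha_eps * deg[v] and v not in in_queue:
                q.append(v)
                in_queue.add(v)
                stats["enqueues"] += 1  # updates that actually enqueue a node
    return p, stats


if __name__ == "__main__":
    indptr, indices, deg = to_csr(TOY_ADJ)
    for order in ("lifo", "fifo"):
        p, stats = ppr_push(0, indptr, indices, deg, 0.1, 1e-4, order)
        secs = timeit.timeit(
            lambda: ppr_push(0, indptr, indices, deg, 0.1, 1e-4, order), number=20)
        print(order, round(sum(p.values()), 4), stats, f"{secs:.3f}s for 20 runs")
```

With popleft() being O(1), the FIFO timing is no longer inflated by list shifting. Inside Numba, where deque does not seem to be available, reading a list through an advancing head index would give the same FIFO order without shifting elements; that variant is an untested suggestion.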

An even better version of this might use heapq. heappush and heappop add some overhead, but heapq is now supported in Numba (which it wasn't when we worked on PPRGo), and a priority queue should need even fewer operations.
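A rough sketch of the heapq idea, again in plain Python and under an assumption the comment does not specify: the priority here is the residual-to-degree ratio r[v] / deg[v], largest first. ppr_push_heap, to_csr and TOY_ADJ are hypothetical names. Instead of a decrease-key operation, each residual update above the threshold pushes a fresh heap entry, and entries whose node has since been pushed are skipped when popped.

```python
import heapq

# Toy graph from the test code above (GraphSAGE example), as an adjacency list.
TOY_ADJ = {0: [1, 12], 1: [0, 2, 5, 9], 2: [1], 3: [5, 6], 4: [5],
           5: [1, 3, 4, 6, 9], 6: [3, 5], 7: [8], 8: [7, 9], 9: [1, 5, 8, 11, 12],
           10: [11], 11: [9, 10], 12: [0, 9, 13, 14], 13: [12, 14], 14: [12, 13]}


def to_csr(adj):
    """Build plain-list CSR arrays (indptr, indices, deg) from an adjacency dict."""
    indptr, indices = [0], []
    for u in range(len(adj)):
        indices.extend(sorted(adj[u]))
        indptr.append(len(indices))
    deg = [indptr[u + 1] - indptr[u] for u in range(len(adj))]
    return indptr, indices, deg


def ppr_push_heap(source, indptr, indices, deg, alpha, epsilon):
    """Push-based approximate PPR that pushes the queued node with the largest
    r[v] / deg[v] first (an assumed priority). Returns (p, pops, edge_updates)."""
    alpha_eps = alpha * epsilon
    p, r = {}, {source: alpha}
    heap = [(-alpha / deg[source], source)]  # heapq is a min-heap, so negate
    pops = edge_updates = 0
    while heap:
        _, u = heapq.heappop(heap)
        res = r.get(u, 0.0)
        if res < alpha_eps * deg[u]:
            continue  # stale entry: u was pushed (residual reset) after it was added
        pops += 1
        p[u] = p.get(u, 0.0) + res
        r[u] = 0.0
        share = (1 - alpha) * res / deg[u]
        for v in indices[indptr[u]:indptr[u + 1]]:
            v = int(v)
            r[v] = r.get(v, 0.0) + share
            edge_updates += 1
            if r[v] >= alpha_eps * deg[v]:
                # lazy update: add a new entry instead of changing the old one
                heapq.heappush(heap, (-r[v] / deg[v], v))
    return p, pops, edge_updates


if __name__ == "__main__":
    indptr, indices, deg = to_csr(TOY_ADJ)
    p, pops, edge_updates = ppr_push_heap(0, indptr, indices, deg, 0.1, 1e-4)
    print(round(sum(p.values()), 4), pops, edge_updates)
```

If a stale entry is popped after the node's residual has grown back above the threshold, the push is still valid, just possibly out of strict priority order. Whether this needs fewer operations than the FIFO variant on real graphs would have to be measured.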
