[FEA] Should Rolling.apply use pure numba rather than jitify? #18033

Open
brandon-b-miller opened this issue Feb 18, 2025 · 0 comments
Labels
feature request (New feature or request), numba (Numba issue), Python (Affects Python cuDF API.)


@brandon-b-miller
Contributor

brandon-b-miller commented Feb 18, 2025

Is your feature request related to a problem? Please describe.
cuDF supports rolling.apply for executing a custom Python function over the specified rolling windows. It works by compiling the function to a PTX string with Numba in Python and handing that string to a rolling aggregation with kind == PTX, which is then used through rolling_window. Ultimately this invokes Jitify and a series of parsing and compilation steps that produce the final kernel that computes the result.

Historically, a similar process made APIs like Series.applymap work, but over time we migrated to an approach that assembles the final kernel in Numba rather than C++, for several reasons:

  • The parsing approach through Jitify breaks in several useful cases, such as when Numba emits PTX containing multiple function definitions
  • The pure Numba approach allows extension types to support null values
  • The pure Numba approach benefits from features such as link-time optimization (LTO) through pynvjitlink

Describe the solution you'd like
A pure Numba implementation of Rolling.apply, possibly with extension type support. For non-nullable data, the implementation could look something like this:

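import numpy as np
from numba import cuda


def count_if_gt_3(window):
    count = 0
    for i in window:
        if i > 3:
            count += 1
    return count


# compile the UDF as a CUDA device function callable from the kernel
devfunc = cuda.jit(device=True)(count_if_gt_3)


@cuda.jit
def kernel(data, win_size, min_periods, out):
    tid = cuda.grid(1)
    if tid >= data.size:
        return  # surplus threads in the last block do nothing

    # trailing window: up to win_size rows ending at (and including) row tid
    start = max(0, tid - win_size + 1)
    end = tid + 1

    thread_win = data[start:end]

    res = devfunc(thread_win)
    out[tid] = res  # min_periods is not handled yet


# s is the input column, e.g. a non-nullable int64 cudf.Series
out = np.zeros(len(s))
# one thread per row; window size 3 and min_periods 1 are example values
kernel.forall(len(s))(s.values, 3, 1, out)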

The above appears to produce correct results in a few local test cases. Of course, a real implementation would need more work to account for min_periods, other dtypes, and so on.
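As a rough illustration of what min_periods handling might involve, here is a CPU-only reference in plain Python that uses the same trailing-window bounds as the kernel above. It assumes pandas-style semantics, where a window holding fewer than min_periods rows yields a null (None here); the name rolling_apply_reference is hypothetical. A reference like this could also serve as an oracle when testing a GPU implementation.

```python
# CPU reference for trailing-window rolling.apply with min_periods.
# Assumption: pandas-style semantics, where a window with fewer than
# min_periods rows produces a null (None here). Names are hypothetical.

def count_if_gt_3(window):
    count = 0
    for i in window:
        if i > 3:
            count += 1
    return count


def rolling_apply_reference(data, win_size, min_periods, func):
    """Apply func to each trailing window of data, one result per row."""
    out = []
    for tid in range(len(data)):
        # Same window bounds as the CUDA kernel above.
        start = max(0, tid - win_size + 1)
        end = tid + 1
        if end - start < min_periods:
            out.append(None)  # too few rows: the result for this row is null
        else:
            out.append(func(data[start:end]))
    return out


print(rolling_apply_reference([1, 5, 2, 7, 4], 3, 3, count_if_gt_3))
# [None, None, 1, 2, 2]
```

In a GPU kernel, the None branch would presumably become a cleared bit in an output null mask rather than a value written to out.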

A null-sensitive implementation seems possible as well, building on much of what we already have. The idea is to assemble each window's data into a cuda.local.array of MaskedType and then pass that array into the UDF as written. The existing implementations of operations between MaskedTypes should take care of the rest.

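# Assumes cudf's Masked / MaskedType and a mask_get(mask, pos) helper that
# reads the validity bit of row pos. cuda.local.array also requires a
# compile-time constant shape, so in practice win_size would need to be a
# global or closure constant rather than a kernel argument.
@cuda.jit
def kernel(data, mask, win_size, min_periods, out_data, out_mask):
    tid = cuda.grid(1)
    if tid >= data.size:
        return

    start = max(0, tid - win_size + 1)
    end = tid + 1

    local = cuda.local.array(win_size, MaskedType(types.int64))

    # place this window of data into thread local memory as an array of MaskedTypes
    for i in range(0, end - start):
        local[i] = Masked(data[start + i], mask_get(mask, start + i))

    # the device function now iterates through the array of MaskedTypes;
    # any operations are resolved through MaskedType's overloads.
    # Near the start of the column only end - start slots are filled.
    res = devfunc(local[:end - start])
    out_data[tid] = res.value
    out_mask[tid] = res.valid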

Currently, however, creating a cuda.local.array of extension types requires a few changes (cc @gmarkall). Once that is supported, this approach would let supported UDFs handle nulls explicitly in conditional logic:

def count_if_gt_3(window):
    count = 0
    for i in window:
        if i != cudf.NA:
            if i > 3:
                count += 1
        else:
            return -1  # any null in the window yields the sentinel -1
    return count
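To make the intended null-aware semantics concrete, here is a CPU-only sketch in plain Python. It is not the proposed GPU implementation: (value, valid) tuples stand in for Masked values, mask_get and masked_window are hypothetical helpers, and the null mask is assumed to follow libcudf's layout of 32-bit words with bit i set when row i is valid (least-significant bit first).

```python
# CPU-only sketch of the null-aware path. mask_get and masked_window are
# hypothetical helpers; (value, valid) tuples stand in for Masked values.
# Assumption: libcudf-style null masks (32-bit words, bit i set = row i valid).
MASK_BITSIZE = 32


def mask_get(mask, pos):
    """Return True if row pos is valid according to the bitmask."""
    return bool((mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1)


def masked_window(data, mask, start, end):
    """Collect rows [start, end) as (value, valid) pairs."""
    return [(data[j], mask_get(mask, j)) for j in range(start, end)]


def count_if_gt_3(window):
    """Null-aware UDF from above; a null row yields the -1 sentinel."""
    count = 0
    for value, valid in window:
        if valid:
            if value > 3:
                count += 1
        else:
            return -1
    return count


def rolling_null_aware(data, mask, win_size):
    out = []
    for tid in range(len(data)):
        # Same trailing-window bounds as the kernels above.
        start = max(0, tid - win_size + 1)
        out.append(count_if_gt_3(masked_window(data, mask, start, tid + 1)))
    return out


data = [1, 5, 2, 7, 4, 6]
mask = [0b111011]  # row 2 is null, all other rows are valid
print(rolling_null_aware(data, mask, 3))  # [0, 1, -1, -1, -1, 3]
```

Because the UDF returns early on the first null, every window that overlaps row 2 produces -1, while the final window, which no longer contains it, counts all three values.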

Describe alternatives you've considered
One disadvantage of this approach is that we would need to reimplement the handling of every keyword argument that rolling currently supports without a custom aggregation, such as center and min_periods. This logic would have to be maintained separately from libcudf.
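To give a sense of the logic that would need to be maintained, here is a hypothetical window_bounds helper in plain Python that computes per-row window bounds and min_periods validity. It assumes pandas' fixed-window semantics, where center=True shifts each window forward by (win_size - 1) // 2 rows and windows are clipped at the column edges; whether this matches libcudf exactly would need checking.

```python
# Hypothetical helper: per-row window bounds for rolling(win_size, center=...).
# Assumption: pandas-style fixed windows; center=True shifts each window
# forward by (win_size - 1) // 2 rows, and bounds are clipped to the column.

def window_bounds(num_rows, win_size, min_periods, center=False):
    """Return (start, end, valid) per row; rows [start, end) form the window."""
    offset = (win_size - 1) // 2 if center else 0
    result = []
    for tid in range(num_rows):
        unclipped_end = tid + 1 + offset
        start = max(0, unclipped_end - win_size)
        end = min(num_rows, unclipped_end)
        # The row is null unless its window holds at least min_periods rows.
        valid = (end - start) >= min_periods
        result.append((start, end, valid))
    return result


print(window_bounds(5, 3, 2, center=True))
# [(0, 2, True), (0, 3, True), (1, 4, True), (2, 5, True), (3, 5, True)]
```

With center=False the bounds reduce to the start = max(0, tid - win_size + 1), end = tid + 1 computation used in the kernels above.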

brandon-b-miller added the feature request, numba, and Python labels on Feb 18, 2025
Projects
Status: Todo