diff --git a/base/sparse/abstractsparse.jl b/base/sparse/abstractsparse.jl index 5a360376841de..bb0feaf29cfea 100644 --- a/base/sparse/abstractsparse.jl +++ b/base/sparse/abstractsparse.jl @@ -42,3 +42,24 @@ function Base.reinterpret(::Type, A::AbstractSparseArray) Try reinterpreting the value itself instead. """) end + +# The following two methods should be overloaded by concrete types to avoid +# allocating the I = find(...) +_sparse_findnextnz(v::AbstractSparseArray, i::Integer) = (I = find(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : zero(indtype(v))) +_sparse_findprevnz(v::AbstractSparseArray, i::Integer) = (I = find(!iszero, v); n = searchsortedlast(I, i); !iszero(n) ? I[n] : zero(indtype(v))) + +function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Integer) + j = _sparse_findnextnz(v, i) + while !iszero(j) && !f(v[j]) + j = _sparse_findnextnz(v, j+1) + end + return j +end + +function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Integer) + j = _sparse_findprevnz(v, i) + while !iszero(j) && !f(v[j]) + j = _sparse_findprevnz(v, j-1) + end + return j +end diff --git a/base/sparse/sparse.jl b/base/sparse/sparse.jl index 8c9133aabe6ab..28ca964b7bdc0 100644 --- a/base/sparse/sparse.jl +++ b/base/sparse/sparse.jl @@ -15,9 +15,9 @@ import Base.LinAlg: mul!, ldiv!, rdiv! import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh, atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot, cotd, coth, count, csc, cscd, csch, adjoint!, diag, diff, done, dot, eig, - exp10, exp2, findn, floor, hash, indmin, inv, issymmetric, istril, istriu, - log10, log2, lu, next, sec, secd, sech, show, sin, - sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan, + exp10, exp2, findn, findprev, findnext, floor, hash, indmin, inv, + issymmetric, istril, istriu, log10, log2, lu, next, sec, secd, sech, show, + sin, sinc, sind, sinh, sinpi, squeeze, start, sum, summary, tan, tand, tanh, trace, transpose!, tril!, triu!, trunc, vecnorm, abs, abs2, broadcast, ceil, complex, cond, conj, convert, copy, copyto!, adjoint, diagm, exp, expm1, factorize, find, findmax, findmin, findnz, float, getindex, diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 1d2516c01c133..205752af48225 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1315,6 +1315,42 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} return (I, J, V) end +function _sparse_findnextnz(m::SparseMatrixCSC, i::Integer) + if i > length(m) + return zero(indtype(m)) + end + row, col = Tuple(CartesianIndices(m)[i]) + lo, hi = m.colptr[col], m.colptr[col+1] + n = searchsortedfirst(m.rowval, row, lo, hi-1, Base.Order.Forward) + if lo <= n <= hi-1 + return LinearIndices(m)[m.rowval[n], col] + end + nextcol = findnext(c->(c>hi), m.colptr, col+1) + if iszero(nextcol) + return zero(indtype(m)) + end + nextlo = m.colptr[nextcol-1] + return LinearIndices(m)[m.rowval[nextlo], nextcol-1] +end + +function _sparse_findprevnz(m::SparseMatrixCSC, i::Integer) + if iszero(i) + return zero(indtype(m)) + end + row, col = Tuple(CartesianIndices(m)[i]) + lo, hi = m.colptr[col], m.colptr[col+1] + n = searchsortedlast(m.rowval, row, lo, hi-1, Base.Order.Forward) + if lo <= n <= hi-1 + return LinearIndices(m)[m.rowval[n], col] + end + prevcol = findprev(c->(c length(v.nzind) + return zero(indtype(v)) + else + return v.nzind[n] + end +end + +function _sparse_findprevnz(v::SparseVector, i::Integer) + n = searchsortedlast(v.nzind, i) + if iszero(n) + return zero(indtype(v)) + else + return v.nzind[n] + end +end + ### Generic functions operating on AbstractSparseVector ### getindex diff --git a/test/sparse/sparse.jl b/test/sparse/sparse.jl index de44d59d8dbd1..9a799de6f9a9d 100644 --- a/test/sparse/sparse.jl +++ b/test/sparse/sparse.jl @@ -2171,6 +2171,37 @@ end @test count(SparseMatrixCSC(2, 2, Int[1, 2, 3], Int[1, 2], Bool[true, true, true])) == 2 end +@testset "sparse findprev/findnext operations" begin + + x = [0,0,0,0,1,0,1,0,1,1,0] + x_sp = sparse(x) + + for i=1:length(x) + @test findnext(!iszero, x,i) == findnext(!iszero, x_sp,i) + @test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i) + end + + y = [0 0 0 0 0; + 1 0 1 0 0; + 1 0 0 0 1; + 0 0 1 0 0; + 1 0 1 1 0] + y_sp = sparse(y) + + for i=1:length(y) + @test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i) + @test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i) + end + + z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1)) + z = collect(z_sp) + + for i=1:length(z) + @test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i) + @test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i) + end +end + # #20711 @testset "vec returns a view" begin local A = sparse(Matrix(1.0I, 3, 3))