Skip to content

Commit 3657a02

Browse files
committed
Add skip predicate to inrange, fixes #53
1 parent 012a3da commit 3657a02

7 files changed

+69
-35
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ idxs, dists = knn(kdtree, point, k, true)
9393
A range search finds all neighbors within the range `r` of given point(s).
9494
This is done with the method:
9595
```jl
96-
inrange(tree, points, r, sortres = false) -> idxs
96+
inrange(tree, points, r, sortres = false, skip = always_false) -> idxs
9797
```
9898
Note that for performance reasons the distances are not returned. The arguments to `inrange` are the same as for `knn` except that `sortres` just sorts the returned index vector.
9999

src/ball_tree.jl

+9-7
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,19 @@ end
189189
function _inrange(tree::BallTree{V},
190190
point::AbstractVector,
191191
radius::Number,
192-
idx_in_ball::Vector{Int}) where {V}
192+
idx_in_ball::Vector{Int},
193+
skip::Function) where {V}
193194
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
194-
inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
195+
inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip) # Call the recursive range finder
195196
return
196197
end
197198

198199
function inrange_kernel!(tree::BallTree,
199200
index::Int,
200201
point::AbstractVector,
201202
query_ball::HyperSphere,
202-
idx_in_ball::Vector{Int})
203+
idx_in_ball::Vector{Int},
204+
skip::Function)
203205
@NODE 1
204206

205207
if index > length(tree.hyper_spheres)
@@ -216,17 +218,17 @@ function inrange_kernel!(tree::BallTree,
216218

217219
# At a leaf node, check all points in the leaf node
218220
if isleaf(tree.tree_data.n_internal_nodes, index)
219-
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true)
221+
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true, skip)
220222
return
221223
end
222224

223225
# The query ball encloses the sub tree bounding sphere. Add all points in the
224226
# sub tree without checking the distance function.
225227
if encloses(tree.metric, sphere, query_ball)
226-
addall(tree, index, idx_in_ball)
228+
addall(tree, index, idx_in_ball, skip)
227229
else
228230
# Recursively call the left and right sub tree.
229-
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
230-
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
231+
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip)
232+
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip)
231233
end
232234
end

src/brute_tree.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,23 @@ end
5555
function _inrange(tree::BruteTree,
5656
point::AbstractVector,
5757
radius::Number,
58-
idx_in_ball::Vector{Int})
59-
inrange_kernel!(tree, point, radius, idx_in_ball)
58+
idx_in_ball::Vector{Int},
59+
skip::Function)
60+
inrange_kernel!(tree, point, radius, idx_in_ball, skip)
6061
return
6162
end
6263

6364

6465
function inrange_kernel!(tree::BruteTree,
6566
point::AbstractVector,
6667
r::Number,
67-
idx_in_ball::Vector{Int})
68+
idx_in_ball::Vector{Int},
69+
skip::Function)
6870
for i in 1:length(tree.data)
71+
if skip(i)
72+
continue
73+
end
74+
6975
@POINT 1
7076
d = evaluate(tree.metric, tree.data[i], point)
7177
if d <= r

src/inrange.jl

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))
22

33
"""
4-
inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
4+
inrange(tree::NNTree, points, radius [, sortres=false, skip=always_false]) -> indices
55
66
Find all the points in the tree which is closer than `radius` to `points`. If
7-
`sortres = true` the resulting indices are sorted.
7+
`sortres = true` the resulting indices are sorted. `skip` is an optional predicate
8+
to determine if a point that would be returned should be skipped.
89
"""
910
function inrange(tree::NNTree,
1011
points::Vector{T},
1112
radius::Number,
12-
sortres=false) where {T <: AbstractVector}
13+
sortres=false,
14+
skip::Function=always_false) where {T <: AbstractVector}
1315
check_input(tree, points)
1416
check_radius(radius)
1517

1618
idxs = [Vector{Int}() for _ in 1:length(points)]
1719

1820
for i in 1:length(points)
19-
inrange_point!(tree, points[i], radius, sortres, idxs[i])
21+
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip)
2022
end
2123
return idxs
2224
end
2325

24-
function inrange_point!(tree, point, radius, sortres, idx)
25-
_inrange(tree, point, radius, idx)
26+
function inrange_point!(tree, point, radius, sortres, idx, skip)
27+
_inrange(tree, point, radius, idx, skip)
2628
if tree.reordered
2729
@inbounds for j in 1:length(idx)
2830
idx[j] = tree.indices[idx[j]]
@@ -32,21 +34,21 @@ function inrange_point!(tree, point, radius, sortres, idx)
3234
return
3335
end
3436

35-
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number}
37+
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number}
3638
check_input(tree, point)
3739
check_radius(radius)
3840
idx = Int[]
39-
inrange_point!(tree, point, radius, sortres, idx)
41+
inrange_point!(tree, point, radius, sortres, idx, skip)
4042
return idx
4143
end
4244

43-
function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false) where {V, T <: Number}
45+
function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number}
4446
dim = size(point, 1)
4547
npoints = size(point, 2)
4648
if isbits(T)
4749
new_data = reinterpret(SVector{dim,T}, point, (length(point) ÷ dim,))
4850
else
4951
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
5052
end
51-
inrange(tree, new_data, radius, sortres)
53+
inrange(tree, new_data, radius, sortres, skip)
5254
end

src/kd_tree.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,11 @@ end
203203
function _inrange(tree::KDTree,
204204
point::AbstractVector,
205205
radius::Number,
206-
idx_in_ball = Int[])
206+
idx_in_ball = Int[],
207+
skip::Function = always_false)
207208
init_min = get_min_distance(tree.hyper_rec, point)
208209
inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball,
209-
init_min)
210+
init_min, skip)
210211
return
211212
end
212213

@@ -216,7 +217,8 @@ function inrange_kernel!(tree::KDTree,
216217
point::AbstractVector,
217218
r::Number,
218219
idx_in_ball::Vector{Int},
219-
min_dist)
220+
min_dist,
221+
skip::Function)
220222
@NODE 1
221223
# Point is outside hyper rectangle, skip the whole sub tree
222224
if min_dist > r
@@ -225,7 +227,7 @@ function inrange_kernel!(tree::KDTree,
225227

226228
# At a leaf node. Go through all points in node and add those in range
227229
if isleaf(tree.tree_data.n_internal_nodes, index)
228-
add_points_inrange!(idx_in_ball, tree, index, point, r, false)
230+
add_points_inrange!(idx_in_ball, tree, index, point, r, false, skip)
229231
return
230232
end
231233

@@ -247,7 +249,7 @@ function inrange_kernel!(tree::KDTree,
247249
ddiff = max(zero(lo - p_dim), lo - p_dim)
248250
end
249251
# Call closer sub tree
250-
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist)
252+
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist, skip)
251253

252254
# TODO: We could potentially also keep track of the max distance
253255
# between the point and the hyper rectangle and add the whole sub tree
@@ -259,5 +261,5 @@ function inrange_kernel!(tree::KDTree,
259261
ddiff_pow = eval_pow(M, ddiff)
260262
diff_tot = eval_diff(M, split_diff_pow, ddiff_pow)
261263
new_min = eval_reduce(M, min_dist, diff_tot)
262-
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min)
264+
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min, skip)
263265
end

src/tree_ops.jl

+17-8
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ end
9494
tree::NNTree, index::Int, point::AbstractVector,
9595
do_end::Bool, skip::F) where {F}
9696
for z in get_leaf_range(tree.tree_data, index)
97+
if skip(tree.indices[z])
98+
continue
99+
end
100+
97101
@POINT 1
98102
idx = tree.reordered ? z : tree.indices[z]
99103
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
100104
if dist_d <= best_dists[1]
101-
if skip(tree.indices[z])
102-
continue
103-
end
104-
105105
best_dists[1] = dist_d
106106
best_idxs[1] = idx
107107
percolate_down!(best_dists, best_idxs, dist_d, idx)
@@ -116,8 +116,13 @@ end
116116
# This will probably prevent SIMD and other optimizations so some care is needed
117117
# to evaluate if it is worth it.
118118
@inline function add_points_inrange!(idx_in_ball::Vector{Int}, tree::NNTree,
119-
index::Int, point::AbstractVector, r::Number, do_end::Bool)
119+
index::Int, point::AbstractVector, r::Number,
120+
do_end::Bool, skip::Function)
120121
for z in get_leaf_range(tree.tree_data, index)
122+
if skip(tree.indices[z])
123+
continue
124+
end
125+
121126
@POINT 1
122127
idx = tree.reordered ? z : tree.indices[z]
123128
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
@@ -129,18 +134,22 @@ end
129134

130135
# Add all points in this subtree since we have determined
131136
# they are all within the desired range
132-
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int})
137+
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int}, skip::Function)
133138
tree_data = tree.tree_data
134139
@NODE 1
135140
if isleaf(tree.tree_data.n_internal_nodes, index)
136141
for z in get_leaf_range(tree.tree_data, index)
142+
if skip(tree.indices[z])
143+
continue
144+
end
145+
137146
@POINT_UNCHECKED 1
138147
idx = tree.reordered ? z : tree.indices[z]
139148
push!(idx_in_ball, idx)
140149
end
141150
return
142151
else
143-
addall(tree, getleft(index), idx_in_ball)
144-
addall(tree, getright(index), idx_in_ball)
152+
addall(tree, getleft(index), idx_in_ball, skip)
153+
addall(tree, getright(index), idx_in_ball, skip)
145154
end
146155
end

test/test_inrange.jl

+13
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,16 @@
4444
end
4545
end
4646
end
47+
48+
@testset "inrange skip" begin
49+
@testset "tree type" for TreeType in trees_with_brute
50+
data = rand(2, 1000)
51+
tree = TreeType(data)
52+
id = 123
53+
54+
idxs = inrange(tree, data[:, id], 2, true)
55+
@test id in idxs
56+
idxs = inrange(tree, data[:, id], 2, true, i -> i == id)
57+
@test !(id in idxs)
58+
end
59+
end

0 commit comments

Comments
 (0)