Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Iterators for path states #221

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ export
maximum_adjacency_visit,

# a-star, dijkstra, bellman-ford, floyd-warshall, desopo-pape, spfa
FloydWarshallIterator,
a_star,
dijkstra_shortest_paths,
bellman_ford_shortest_paths,
Expand Down
7 changes: 7 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ An abstract type that provides information from shortest paths calculations.
"""
abstract type AbstractPathState end

"""
AbstractPathIterator

An Abstract type that can be iterated over to get all paths encoded in an `AbstractPathState`.
"""
abstract type AbstractPathIterator end

"""
is_ordered(e)

Expand Down
98 changes: 77 additions & 21 deletions src/shortestpaths/floyd-warshall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ struct FloydWarshallState{T,U<:Integer} <: AbstractPathState
parents::Matrix{U}
end

"""
struct FloydWarshallIterator{T, U}

An [`AbstractPathIterator`](@ref) which, under iteration gives all shortes paths encoded in a FloydWarshallState.
When collected, returns a Matrix{Vector{U}} m, where m[s,d] is the shortes path from node s to node d if it exists,
otherwise [].
"""
struct FloydWarshallIterator{T,U<:Integer} <: AbstractPathIterator
path_state::FloydWarshallState{T,U}
end

@doc """
floyd_warshall_shortest_paths(g, distmx=weights(g))

Expand Down Expand Up @@ -69,31 +80,76 @@ function floyd_warshall_shortest_paths(
end

function enumerate_paths(
s::FloydWarshallState{T,U}, v::Integer
st::FloydWarshallState{T,U}, s::Integer, d::Integer
) where {T} where {U<:Integer}
pathinfo = s.parents[v, :]
paths = Vector{Vector{U}}()
for i in 1:length(pathinfo)
if (i == v) || (s.dists[v, i] == typemax(T))
push!(paths, Vector{U}())
else
path = Vector{U}()
currpathindex = U(i)
while currpathindex != 0
push!(path, currpathindex)
if pathinfo[currpathindex] == currpathindex
currpathindex = zero(currpathindex)
else
currpathindex = pathinfo[currpathindex]
end
pathinfo = @view st.parents[s, :]
path = Vector{U}()
if (s == d) || (st.dists[s, d] == typemax(T))
return path
else
currpathindex = U(d)
while currpathindex != 0
push!(path, currpathindex)
if pathinfo[currpathindex] == currpathindex
currpathindex = zero(currpathindex)
else
currpathindex = pathinfo[currpathindex]
end
push!(paths, reverse(path))
end
return reverse!(path)
end
end

function enumerate_paths(st::FloydWarshallState)
return [enumerate_paths(st, s) for s in 1:size(st.parents, 1)]
end

function enumerate_paths(st::FloydWarshallState, s::Integer)
return [enumerate_paths(st, s, d) for d in 1:size(st.parents, 1)]
end

function enumerate_path_into!(
pathcontainer, iter::Graphs.FloydWarshallIterator, s::Integer, d::Integer
)
if iter.path_state.parents[s, d] == 0 || s == d
return @view pathcontainer[2:1:1]
else
pathcontainer[1] = d
current_node = 2
while d != s
d = iter.path_state.parents[s, d]
pathcontainer[current_node] = d
current_node += 1
end
return @view pathcontainer[(current_node - 1):-1:1]
end
return paths
end

function enumerate_paths(s::FloydWarshallState)
return [enumerate_paths(s, v) for v in 1:size(s.parents, 1)]
function Base.iterate(iter::FloydWarshallIterator{T,U}) where {T,U<:Integer}
pathcontainer = Vector{U}(undef, size(iter.path_state.dists, 1))
pathview = enumerate_path_into!(pathcontainer, iter, 1, 1)
state = (source=1, destination=1, pathcontainer=pathcontainer)
return pathview, state
end

function Base.iterate(iter::FloydWarshallIterator{T,U}, state) where {T,U<:Integer}
a1 = axes(iter.path_state.dists, 1)
s, d, pathcontainer = state
if d + 1 in a1
d += 1
elseif s + 1 in a1
s += 1
d = 1
else
return nothing
end
pathview = enumerate_path_into!(pathcontainer, iter, s, d)
return pathview, (source=s, destination=d, pathcontainer=pathcontainer)
end
enumerate_paths(st::FloydWarshallState, s::Integer, d::Integer) = enumerate_paths(st, s)[d]

Base.IteratorSize(::FloydWarshallIterator) = Base.HasShape{2}()

Base.size(iter::FloydWarshallIterator) = size(iter.path_state.dists)
Base.length(iter::FloydWarshallIterator) = length(iter.path_state.dists)

Base.collect(iter::FloydWarshallIterator) = permutedims([collect(i) for i in iter])
41 changes: 34 additions & 7 deletions test/shortestpaths/floyd-warshall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,40 @@
@test z.dists[3, :][:] == [7, 6, 0, 11, 27]
@test z.parents[3, :][:] == [2, 3, 0, 3, 4]

@test @inferred(enumerate_paths(z))[2][2] == []
@test @inferred(enumerate_paths(z))[2][4] ==
all_paths = @inferred(enumerate_paths(z))
all_paths_flat = [all_paths...;]
@test all_paths[2][2] == []
@test all_paths[2][4] ==
enumerate_paths(z, 2)[4] ==
enumerate_paths(z, 2, 4) ==
[2, 3, 4]

z_iter = @inferred(FloydWarshallIterator(z))
first_iter = iterate(z_iter)
@test first_iter[1] == all_paths[1][1] == []
@test size(z_iter) == (5, 5)
@test length(z_iter) == 25
@test mapreduce(==, &, z_iter, all_paths_flat)

collected_iter = collect(z_iter)
@test mapreduce(&, eachrow(collected_iter), all_paths) do row, paths
mapreduce(==, &, row, paths)
end
end

g4 = path_digraph(4)
d = ones(4, 4)
for g in testdigraphs(g4)
z = @inferred(floyd_warshall_shortest_paths(g, d))
@test length(enumerate_paths(z, 4, 3)) == 0
@test length(enumerate_paths(z, 4, 1)) == 0
@test length(enumerate_paths(z, 2, 3)) == 2

z_iter = @inferred(FloydWarshallIterator(z))
iter_paths = collect(z_iter)
@test length(iter_paths[4, 3]) == 0
@test length(iter_paths[4, 1]) == 0
@test length(iter_paths[2, 3]) == 2
end

g5 = DiGraph([1 1 1 0 1; 0 1 0 1 1; 0 1 1 0 0; 1 0 1 1 0; 0 0 0 1 1])
Expand All @@ -32,15 +53,21 @@
g = SimpleGraph(2)
add_edge!(g, 1, 2)
add_edge!(g, 2, 2)
@test enumerate_paths(floyd_warshall_shortest_paths(g)) ==
Vector{Vector{Int}}[[[], [1, 2]], [[2, 1], []]]
z = floyd_warshall_shortest_paths(g)
@test enumerate_paths(z) == Vector{Vector{Int}}[[[], [1, 2]], [[2, 1], []]]

g = SimpleDiGraph(2)
@test mapreduce(
==, &, eachrow(collect(FloydWarshallIterator(z))), enumerate_paths(z)
)
add_edge!(g, 1, 1)
add_edge!(g, 1, 2)
add_edge!(g, 2, 1)
add_edge!(g, 2, 2)
@test enumerate_paths(floyd_warshall_shortest_paths(g)) ==
Vector{Vector{Int}}[[[], [1, 2]], [[2, 1], []]]
z = floyd_warshall_shortest_paths(g)
@test enumerate_paths(z) == Vector{Vector{Int}}[[[], [1, 2]], [[2, 1], []]]

@test mapreduce(
==, &, eachrow(collect(FloydWarshallIterator(z))), enumerate_paths(z)
)
end
end