From 0ef4e34e5a62f481eeeb5c885049a97b28c658a5 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Thu, 24 Aug 2023 18:17:29 +0100 Subject: [PATCH 1/8] Add GPU array tests --- test/projection.jl | 102 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 21 deletions(-) diff --git a/test/projection.jl b/test/projection.jl index d364631fc..bfa82f414 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,6 +1,7 @@ using ChainRulesCore, Test using LinearAlgebra, SparseArrays using OffsetArrays, StaticArrays, BenchmarkTools +using JLArrays # Like ForwardDiff.jl's Dual struct Dual{T<:Real} <: Real @@ -50,7 +51,7 @@ struct NoSuperType end # real & complex @test ProjectTo(1.0 + 1im)(Dual(1.0, 2.0)) isa Complex{<:Dual} @test ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa - Complex{<:Dual} + Complex{<:Dual} @test ProjectTo(1.0)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))) isa Dual # Tangent @@ -143,7 +144,7 @@ struct NoSuperType end @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} - + @test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero end @@ -154,7 +155,7 @@ struct NoSuperType end @test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent @test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any} end - @test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector + @test pt1([1]) == Tangent{Tuple{Float64}}(1.0) # accepts Vector @test @inferred(pt1(NoTangent())) === NoTangent() @test @inferred(pt1(ZeroTangent())) === ZeroTangent() @test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero @@ -163,7 +164,9 @@ struct NoSuperType end @test_throws Exception pt1([]) pt3 = ProjectTo(([1, 2, 3], false, :gamma)) # partly non-differentiable - @test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int}, Bool, Symbol}}([1.0, 2.0, 3.0], NoTangent(), NoTangent()) + @test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int},Bool,Symbol}}( + [1.0, 2.0, 3.0], NoTangent(), NoTangent() + ) @test ProjectTo((true, [false])) isa ProjectTo{NoTangent} end @@ -216,7 +219,7 @@ struct NoSuperType end @testset "UniformScaling" begin @test ProjectTo(I)(123) === NoTangent() @test ProjectTo(2 * I)(I * 3im) === 0.0 * I - @test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ = 6)) === (6.0 + 0.0im) * I + @test ProjectTo((4 + 5im) * I)(Tangent{typeof(im * I)}(; λ=6)) === (6.0 + 0.0im) * I @test ProjectTo(7 * I)(Tangent{typeof(2I)}()) == ZeroTangent() end @@ -375,29 +378,86 @@ struct NoSuperType end pvec3 = ProjectTo([1, 2, 3]) @test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,) @test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test - @test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector + @test pvec3(OffsetArray(rand(3, 1), 0:2, 0:0)) isa Vector end ##### ##### `StaticArrays` ##### - @testset "StaticArrays" begin - # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) - # implies a check, and reshape will wrap a Vector into a static SizedVector: - pstat = ProjectTo(SA[1, 2, 3]) - @test axes(pstat(rand(3))) === (SOneTo(3),) - - # This recurses into structured arrays: - pst = ProjectTo(transpose(SA[1, 2, 3])) - @test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3)) - @test pst(rand(1,3)) isa Transpose - - # When the argument is an ordinary Array, static gradients are allowed to pass, - # like FillArrays. Collecting to an Array would cost a copy. - pvec3 = ProjectTo([1, 2, 3]) - @test pvec3(SA[1, 2, 3]) isa StaticArray + @testset "StaticArrays" begin + # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx) + # implies a check, and reshape will wrap a Vector into a static SizedVector: + pstat = ProjectTo(SA[1, 2, 3]) + @test axes(pstat(rand(3))) === (SOneTo(3),) + + # This recurses into structured arrays: + pst = ProjectTo(transpose(SA[1, 2, 3])) + @test axes(pst(rand(1, 3))) === (SOneTo(1), SOneTo(3)) + @test pst(rand(1, 3)) isa Transpose + + # When the argument is an ordinary Array, static gradients are allowed to pass, + # like FillArrays. Collecting to an Array would cost a copy. + pvec3 = ProjectTo([1, 2, 3]) + @test pvec3(SA[1, 2, 3]) isa StaticArray + end + + ##### + ##### `GPU arrays` + ##### + + # issue #624 + @testset "GPUArrays" begin + JLVector = JLArray{T,1} where {T} + JLMatrix = JLArray{T,2} where {T} + + pvec3 = ProjectTo(JLArray([1, 2, 3])) + @test pvec3(JLArray(1.0:3.0)) == JLArray(1.0:3.0) + @test pvec3(JLArray(1:3)) == JLArray(1.0:3.0) # would prefer ===, map(Float64, dx) would do that, not important + @test pvec3(JLArray([1, 2, 3 + 4im])) == JLArray(1:3) + @test eltype(pvec3(JLArray([1, 2, 3.0f0]))) === Float64 + + # reshape + @test pvec3(reshape(JLArray([1, 2, 3]), 3, 1)) isa JLVector + @test_throws DimensionMismatch pvec3(reshape(JLArray([1, 2, 3]), 1, 3)) + @test_throws DimensionMismatch pvec3(JLArray([1, 2, 3, 4])) + + pmat = ProjectTo(JLArray(rand(2, 2) .+ im)) + @test pmat(JLArray([1 2; 3 4.0+5im])') isa Adjoint # pass-through + @test pmat(JLArray([1 2; 3 4])') isa JLMatrix # broadcast type change + + pmat2 = ProjectTo(JLArray(rand(2, 2))') + @test pmat2(JLArray([1 2; 3 4.0+5im])) isa JLMatrix # adjoint matrices are not re-created + + prow = ProjectTo(JLArray([1im 2 3im])) + @test prow(transpose(JLArray([1, 2, 3 + 4.0im]))) == JLArray([1 2 3 + 4im]) + @test prow(transpose(JLArray([1, 2, 3 + 4.0im]))) isa JLMatrix # row vectors may not pass through + @test prow(adjoint(JLArray([1, 2, 3 + 5im]))) == JLArray([1 2 3 - 5im]) + @test prow(adjoint(JLArray([1, 2, 3]))) isa JLMatrix + + # some bugs + @test pvec3(JLArray(fill(NoTangent(), 3))) === NoTangent() #410, was an array of such + @test ProjectTo(JLArray([pi]))(JLArray([1])) isa JLVector{Int} #423, was Irrational -> Bool -> NoTangent + + # adjoint vectors + @testset "GPUArrays: $adj vectors" for adj in [transpose, adjoint] + padj = ProjectTo(adj(JLArray([1, 2, 3]))) + adjT = typeof(adj(JLArray([1, 2, 3.0]))) + @test padj(transpose(JLArray(1:3))) isa adjT + @test padj(JLArray([4 5 6 + 7im])) isa adjT + @test padj(JLArray([4.0 5.0 6.0])) isa adjT + + @test_throws DimensionMismatch padj(JLArray([1, 2, 3])) + @test_throws DimensionMismatch padj(JLArray([1 2 3]')) + @test_throws DimensionMismatch padj(JLArray([1 2 3 4])) + + padj_complex = ProjectTo(adj(JLArray([1, 2, 3 + 4im]))) + @test padj_complex(JLArray([4 5 6 + 7im])) == JLArray([4 5 6 + 7im]) + @test padj_complex(transpose(JLArray([4, 5, 6 + 7im]))) == + JLArray([4 5 6 + 7im]) + @test padj_complex(adjoint(JLArray([4, 5, 6 + 7im]))) == JLArray([4 5 6 - 7im]) end + end ##### ##### `ChainRulesCore` From 553223989f18a796ac3c8244f78dd42358a96121 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 09:34:49 +0100 Subject: [PATCH 2/8] Add extra tests --- Project.toml | 10 +++++++++- test/projection.jl | 7 +++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 20984db45..81e94dcf3 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,15 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "StaticArrays"] +test = [ + "Test", + "BenchmarkTools", + "FiniteDifferences", + "OffsetArrays", + "StaticArrays", + "JLArrays", +] diff --git a/test/projection.jl b/test/projection.jl index bfa82f414..91511797b 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -456,6 +456,13 @@ struct NoSuperType end @test padj_complex(transpose(JLArray([4, 5, 6 + 7im]))) == JLArray([4 5 6 + 7im]) @test padj_complex(adjoint(JLArray([4, 5, 6 + 7im]))) == JLArray([4 5 6 - 7im]) + + # issue #410 + @test padj(JLArray([NoTangent() NoTangent() NoTangent()])) === NoTangent() + + @test ProjectTo(adj(JLArray([true, false])))(JLArray([1 2])) isa AbstractZero + @test ProjectTo(adj([JLArray([true]), JLArray([false])])) isa + ProjectTo{<:AbstractZero} end end From 29f76f7808f71e63448ff80a73432d17bfa2f899 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 11:12:40 +0100 Subject: [PATCH 3/8] Fix proj of trans, adj GPUArrays onto GPUArrays --- Project.toml | 3 +++ src/ChainRulesCore.jl | 1 + src/projection.jl | 9 ++++++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 81e94dcf3..4c3b7433c 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.16.0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" [compat] BenchmarkTools = "0.5" @@ -13,6 +14,8 @@ Compat = "2, 3, 4" FiniteDifferences = "0.10" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" +GPUArraysCore = "0.1" +JLArrays = "0.1" julia = "1.6" [extras] diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index b81ab4fba..864f06289 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -3,6 +3,7 @@ using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, mate using Base.Meta using LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC +using GPUArraysCore using Compat: hasfield, hasproperty export frule, rrule # core function diff --git a/src/projection.jl b/src/projection.jl index 811802536..cf0960908 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -248,6 +248,14 @@ function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec) return project(reshape(vec(dx), 1, :)) end +# Nested GPUArray wrappers lead to scalar indexing, try to prevent that: +function (project::ProjectTo{AbstractArray})(dx::Transpose{T,A}) where {T,A<:AbstractGPUVector} + return project(copy(reshape(vec(dx), 1, :))) +end +function (project::ProjectTo{AbstractArray})(dx::Adjoint{T,A}) where {T,A<:AbstractGPUVector} + return project(copy(reshape(conj.(adjoint(dx)), 1, :))) +end + # Zero-dimensional arrays -- these have a habit of going missing, # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers @@ -385,7 +393,6 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) end end - ##### ##### `LinearAlgebra` ##### From 3397e67b9bccb73aaec0d9f5e3425b4a8dd0b4d0 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 13:23:07 +0100 Subject: [PATCH 4/8] Fix proj onto transposed and adjointed GPUArrays --- src/projection.jl | 53 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index cf0960908..435cc2c6c 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -227,7 +227,11 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} throw(_projection_mismatch(project.axes, size(dx))) end end - reshape(dx, project.axes) + if dx isa AbstractGPUArray + copy(reshape(dx, project.axes)) + else + reshape(dx, project.axes) + end end # Then deal with the elements. One projector if AbstractArray{<:Number}, # or one per element for arrays of anything else, including arrays of arrays: @@ -248,14 +252,6 @@ function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec) return project(reshape(vec(dx), 1, :)) end -# Nested GPUArray wrappers lead to scalar indexing, try to prevent that: -function (project::ProjectTo{AbstractArray})(dx::Transpose{T,A}) where {T,A<:AbstractGPUVector} - return project(copy(reshape(vec(dx), 1, :))) -end -function (project::ProjectTo{AbstractArray})(dx::Adjoint{T,A}) where {T,A<:AbstractGPUVector} - return project(copy(reshape(conj.(adjoint(dx)), 1, :))) -end - # Zero-dimensional arrays -- these have a habit of going missing, # although really Ref() is probably a better structure. function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers @@ -620,3 +616,42 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) invoke(project, Tuple{AbstractArray}, dx) end end + +##### +##### `GPUArrays` +##### + +# Row vectors aren't acceptable as gradients for 1-row matrices: +# Nested GPUArray wrappers lead to scalar indexing, try to prevent that: +function (project::ProjectTo{AbstractArray})( + dx::Transpose{T,A} +) where {T,A<:AbstractGPUVector} + return project(copy(reshape(vec(dx), 1, :))) +end +function (project::ProjectTo{AbstractArray})( + dx::Adjoint{T,A} +) where {T,A<:AbstractGPUVector} + return project(copy(reshape(conj(adjoint(dx)), 1, :))) +end + +AdjOrTransAbsGPUVec = Union{Adjoint{T,A},Transpose{T,A}} where {T,A<:AbstractGPUVector} +function (project::ProjectTo{Adjoint})(dx::AdjOrTransAbsGPUVec) + return adjoint(project.parent(conj(transpose(dx)))) +end +function (project::ProjectTo{Adjoint})(dx::AbstractGPUArray) + if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1]) + throw(_projection_mismatch((1:1, project.parent.axes...), size(dx))) + end + dy = eltype(dx) <: Real ? copy(vec(dx)) : copy(adjoint(dx)) + return adjoint(project.parent(dy)) +end +function (project::ProjectTo{Transpose})(dx::AdjOrTransAbsGPUVec) + return transpose(project.parent(conj(adjoint(dx)))) +end +function (project::ProjectTo{Transpose})(dx::AbstractGPUArray) + if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1]) + throw(_projection_mismatch((1:1, project.parent.axes...), size(dx))) + end + dy = eltype(dx) <: Number ? copy(vec(dx)) : copy(transpose(dx)) + return transpose(project.parent(dy)) +end \ No newline at end of file From 49df74d2b2e59b8baeb9818152d25afcfc7db7fb Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 13:35:30 +0100 Subject: [PATCH 5/8] Add comment and issue link --- src/projection.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index 435cc2c6c..3ad67ddd0 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -621,6 +621,8 @@ end ##### `GPUArrays` ##### +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/624 + # Row vectors aren't acceptable as gradients for 1-row matrices: # Nested GPUArray wrappers lead to scalar indexing, try to prevent that: function (project::ProjectTo{AbstractArray})( @@ -634,6 +636,8 @@ function (project::ProjectTo{AbstractArray})( return project(copy(reshape(conj(adjoint(dx)), 1, :))) end +# Make sure wrappers are either cancelled or materialized to maintain a maximum +# wrapper depth of 1: AdjOrTransAbsGPUVec = Union{Adjoint{T,A},Transpose{T,A}} where {T,A<:AbstractGPUVector} function (project::ProjectTo{Adjoint})(dx::AdjOrTransAbsGPUVec) return adjoint(project.parent(conj(transpose(dx)))) From 4ad78b1889b7ca5c2fc7aaf3b4415073e9b8afc0 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 13:49:34 +0100 Subject: [PATCH 6/8] Add comment and issue link near separate edit --- src/projection.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 3ad67ddd0..6e232a809 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -227,7 +227,9 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} throw(_projection_mismatch(project.axes, size(dx))) end end - if dx isa AbstractGPUArray + # Reshape, copying to remove the wrapper if a GPUArray, see + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/624 + if dx isa AbstractGPUArray copy(reshape(dx, project.axes)) else reshape(dx, project.axes) From 0c6092ec03a46ad8ce1214ee4500b198f2237d89 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 13:51:04 +0100 Subject: [PATCH 7/8] Edit comment --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 6e232a809..2ce03d7dc 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -638,7 +638,7 @@ function (project::ProjectTo{AbstractArray})( return project(copy(reshape(conj(adjoint(dx)), 1, :))) end -# Make sure wrappers are either cancelled or materialized to maintain a maximum +# Make sure wrappers either cancel out or are materialized to maintain a maximum # wrapper depth of 1: AdjOrTransAbsGPUVec = Union{Adjoint{T,A},Transpose{T,A}} where {T,A<:AbstractGPUVector} function (project::ProjectTo{Adjoint})(dx::AdjOrTransAbsGPUVec) From 699ada3d766e1d918fac3c025855f1c691f75f61 Mon Sep 17 00:00:00 2001 From: DomCRose Date: Mon, 4 Sep 2023 14:56:54 +0100 Subject: [PATCH 8/8] Fix whitespace --- src/projection.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 2ce03d7dc..6c58cc0e2 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -229,7 +229,7 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} end # Reshape, copying to remove the wrapper if a GPUArray, see # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/624 - if dx isa AbstractGPUArray + if dx isa AbstractGPUArray copy(reshape(dx, project.axes)) else reshape(dx, project.axes) @@ -660,4 +660,4 @@ function (project::ProjectTo{Transpose})(dx::AbstractGPUArray) end dy = eltype(dx) <: Number ? copy(vec(dx)) : copy(transpose(dx)) return transpose(project.parent(dy)) -end \ No newline at end of file +end