-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathrand_tangent.jl
118 lines (103 loc) · 3.9 KB
/
rand_tangent.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Fixture type used to exercise `rand_tangent` and `difference` on a user-defined
# composite: a `Float64` field, an `Int` field, and an untyped field that can hold
# anything (arrays, numbers, or another `Bar`, allowing nested structs in tests).
struct Bar
    a::Float64  # floating-point field
    b::Int      # integer field
    c           # untyped, i.e. implicitly `::Any`
end
# Tests for `rand_tangent`. The testset checks four things:
#   * For each primal, both the explicit-RNG method and the default-RNG method return a
#     tangent of the expected type.
#   * The generic struct fallback throws for a non-struct input.
#   * A random tangent can be added back onto its primal.
#   * On Julia >= 1.6, random tangents print as short strings.
@testset "rand_tangent" begin
    # Seeded so that generated tangents are reproducible between runs.
    rng = MersenneTwister(123456)

    @testset "Primal: $(typeof(x)), Tangent: $T_tangent" for (x, T_tangent) in [
        # Things without sensible tangents.
        ("hi", NoTangent),
        ('a', NoTangent),
        (:a, NoTangent),
        (true, NoTangent),
        (4, NoTangent),
        (FiniteDifferences, NoTangent),  # Module object

        # Types (not instances of type)
        (Bar, NoTangent),
        (Union{Int, Bar}, NoTangent),
        (Vector, NoTangent),
        (Vector{Float64}, NoTangent),
        (Integer, NoTangent),
        (Type{<:Real}, NoTangent),

        # Numbers.
        (5.0, Float64),
        (5.0 + 0.4im, Complex{Float64}),
        (big(5.0), BigFloat),

        # StridedArrays.
        (fill(randn(Float32)), Array{Float32, 0}),
        (fill(randn(Float64)), Array{Float64, 0}),
        (randn(Float32, 3), Vector{Float32}),
        (randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
        (randn(5, 4), Matrix{Float64}),
        (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
        ([randn(5, 4), 4.0], Vector{Any}),

        # Co-Arrays
        (randn(5)', Adjoint{Float64, Vector{Float64}}),  # row-vector: special
        (randn(5, 4)', Matrix{Float64}),                 # matrix: generic dense
        (transpose(randn(5)), Transpose{Float64, Vector{Float64}}),  # row-vector: special
        (transpose(randn(5, 4)), Matrix{Float64}),                   # matrix: generic dense

        # AbstractArrays of non-perturbable types
        (1:10, NoTangent),
        (1:2:10, NoTangent),
        ([false, true], NoTangent),

        # Tuples.
        ((4.0, ), Tangent{Tuple{Float64}}),
        ((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
        ((false, true), NoTangent),
        (Tuple{}(), NoTangent),

        # NamedTuples.
        ((a=4.0, ), Tangent{NamedTuple{(:a,), Tuple{Float64}}}),
        ((a=5.0, b=1), Tangent{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
        ((a=false, b=true), NoTangent),
        ((;), NoTangent),

        # structs.
        (Bar(5.0, 4, rand(rng, 3)), Tangent{Bar}),
        (Bar(4.0, 3, Bar(5.0, 2, 4)), Tangent{Bar}),
        (sin, NoTangent),
        # all fields NoTangent implies NoTangent
        (Pair(:a, "b"), NoTangent),
        (CartesianIndex(2, 3), NoTangent),

        # LinearAlgebra types
        (
            UpperTriangular(randn(3, 3)),
            UpperTriangular{Float64, Matrix{Float64}},
        ),
        (
            Diagonal(randn(2)),
            Diagonal{Float64, Vector{Float64}},
        ),
        (
            Symmetric(randn(2, 2)),
            Symmetric{Float64, Matrix{Float64}},
        ),
        (
            Hermitian(randn(ComplexF64, 1, 1)),
            Hermitian{ComplexF64, Matrix{ComplexF64}},
        ),
    ]
        # Exercise both the explicit-RNG method and the default-RNG convenience method.
        @test rand_tangent(rng, x) isa T_tangent
        @test rand_tangent(x) isa T_tangent
    end

    @testset "erroring cases" begin
        # Ensure struct fallback errors for non-struct types: `invoke` forces dispatch to
        # the `(AbstractRNG, Any)` method even though a more specific Float64 method exists.
        @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
    end

    @testset "composition of addition" begin
        # A nested struct mixing a Float64, an Int, and a Vector; adding random tangents
        # (or a sum of them) back onto the primal must give back a `Bar`.
        # Uses the seeded `rng` so the test is reproducible.
        x = Bar(1.5, 2, Bar(1.1, 3, [1.7, 1.4, 0.9]))
        @test x + rand_tangent(rng, x) isa typeof(x)
        @test x + (rand_tangent(rng, x) + rand_tangent(rng, x)) isa typeof(x)
    end

    # Julia 1.6 switched to the Ryu float-printing algorithm, which seems better at
    # printing short representations of numbers, so these length bounds are only
    # checked on 1.6 and later.
    VERSION >= v"1.6" && @testset "niceness of printing" begin
        rng = MersenneTwister(1)
        for _ in 1:50
            @test length(string(rand_tangent(rng, 1.0))) <= 6
            @test length(string(rand_tangent(rng, 1.0 + 1.0im))) <= 12
            @test length(string(rand_tangent(rng, 1f0))) <= 9
            @test length(string(rand_tangent(rng, big"1.0"))) <= 9
        end
    end
end