Skip to content

Commit 47663bd

Browse files
tpappvtjnash
andauthored
Normalize indices in promote_shape error messages (#41311)
Seeing implementation details like `Base.OneTo` in error messages may be confusing to some users (cf discussion in #39242, [discourse](https://discourse.julialang.org/t/promote-shape-dimension-mismatch/57529/)). This PR turns ```julia julia> ones(2, 3) + ones(3, 2) ERROR: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(2), Base.OneTo(3)), b has dims (Base.OneTo(3), Base.OneTo(2)), mismatch at 1") ``` into ```julia julia> ones(2, 3) + ones(3, 2) ERROR: DimensionMismatch("dimensions must match: a has size (2, 3), b has size (3, 2), mismatch at 1") ``` Fixes #40118. (This is basically #40124, but redone because I made a mess rebasing). --------- Co-authored-by: Jameson Nash <[email protected]>
1 parent 831cc14 commit 47663bd

File tree

4 files changed

+46
-24
lines changed

4 files changed

+46
-24
lines changed

base/indices.jl

+36-21
Original file line numberDiff line numberDiff line change
@@ -106,26 +106,49 @@ IndexStyle(::IndexStyle, ::IndexStyle) = IndexCartesian()
106106

107107
promote_shape(::Tuple{}, ::Tuple{}) = ()
108108

109-
function promote_shape(a::Tuple{Int,}, b::Tuple{Int,})
110-
if a[1] != b[1]
111-
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b"))
109+
# Consistent error message for promote_shape mismatch, hiding type details like
110+
# OneTo. When b ≡ nothing, it is omitted; i can be supplied for an index.
111+
function throw_promote_shape_mismatch(a::Tuple, b::Union{Nothing,Tuple}, i = nothing)
112+
if a isa Tuple{Vararg{Base.OneTo}} && (b === nothing || b isa Tuple{Vararg{Base.OneTo}})
113+
a = map(lastindex, a)::Dims
114+
b === nothing || (b = map(lastindex, b)::Dims)
115+
end
116+
_has_axes = !(a isa Dims && (b === nothing || b isa Dims))
117+
if _has_axes
118+
_normalize(d) = map(x -> firstindex(x):lastindex(x), d)
119+
a = _normalize(a)
120+
b === nothing || (b = _normalize(b))
121+
_things = "axes "
122+
else
123+
_things = "size "
124+
end
125+
msg = IOBuffer()
126+
print(msg, "a has ", _things)
127+
print(msg, a)
128+
if b nothing
129+
print(msg, ", b has ", _things)
130+
print(msg, b)
112131
end
132+
if i nothing
133+
print(msg, ", mismatch at dim ", i)
134+
end
135+
throw(DimensionMismatch(String(take!(msg))))
136+
end
137+
138+
function promote_shape(a::Tuple{Int,}, b::Tuple{Int,})
139+
a[1] != b[1] && throw_promote_shape_mismatch(a, b)
113140
return a
114141
end
115142

116143
function promote_shape(a::Tuple{Int,Int}, b::Tuple{Int,})
117-
if a[1] != b[1] || a[2] != 1
118-
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b"))
119-
end
144+
(a[1] != b[1] || a[2] != 1) && throw_promote_shape_mismatch(a, b)
120145
return a
121146
end
122147

123148
promote_shape(a::Tuple{Int,}, b::Tuple{Int,Int}) = promote_shape(b, a)
124149

125150
function promote_shape(a::Tuple{Int, Int}, b::Tuple{Int, Int})
126-
if a[1] != b[1] || a[2] != b[2]
127-
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b"))
128-
end
151+
(a[1] != b[1] || a[2] != b[2]) && throw_promote_shape_mismatch(a, b)
129152
return a
130153
end
131154

@@ -153,14 +176,10 @@ function promote_shape(a::Dims, b::Dims)
153176
return promote_shape(b, a)
154177
end
155178
for i=1:length(b)
156-
if a[i] != b[i]
157-
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b, mismatch at $i"))
158-
end
179+
a[i] != b[i] && throw_promote_shape_mismatch(a, b, i)
159180
end
160181
for i=length(b)+1:length(a)
161-
if a[i] != 1
162-
throw(DimensionMismatch("dimensions must match: a has dims $a, must have singleton at dim $i"))
163-
end
182+
a[i] != 1 && throw_promote_shape_mismatch(a, nothing, i)
164183
end
165184
return a
166185
end
@@ -174,14 +193,10 @@ function promote_shape(a::Indices, b::Indices)
174193
return promote_shape(b, a)
175194
end
176195
for i=1:length(b)
177-
if a[i] != b[i]
178-
throw(DimensionMismatch("dimensions must match: a has dims $a, b has dims $b, mismatch at $i"))
179-
end
196+
a[i] != b[i] && throw_promote_shape_mismatch(a, b, i)
180197
end
181198
for i=length(b)+1:length(a)
182-
if a[i] != 1:1
183-
throw(DimensionMismatch("dimensions must match: a has dims $a, must have singleton at dim $i"))
184-
end
199+
a[i] != 1:1 && throw_promote_shape_mismatch(a, nothing, i)
185200
end
186201
return a
187202
end

stdlib/LinearAlgebra/src/matmul.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, Ab
184184
end
185185
for d in ndims(Ay)+1:ndims(z)
186186
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
187-
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
187+
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
188188
axes(z), ", must have singleton at dim ", d)))
189189
end
190190
Ay .+ z
@@ -197,7 +197,7 @@ function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, Ab
197197
end
198198
for d in 3:ndims(z)
199199
# Similar error to (u*v) + z:
200-
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
200+
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
201201
axes(z), ", must have singleton at dim ", d)))
202202
end
203203
(u .* v) .+ z

stdlib/LinearAlgebra/src/structuredbroadcast.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ end
251251
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
252252
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
253253
sz = size(A)
254-
all(map(B->size(B)==sz, Bs)) || throw(DimensionMismatch("dimensions must match"))
254+
for B in Bs
255+
size(B) == sz || Base.throw_promote_shape_mismatch(sz, size(B))
256+
end
255257
return f.(A, Bs...)
256258
end

stdlib/LinearAlgebra/test/structuredbroadcast.jl

+5
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ end
142142
@test map!(*, Z, X, Y) == broadcast(*, fX, fY)
143143
end
144144
end
145+
# these would be valid for broadcast, but not for map
146+
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)))
147+
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)), D)
148+
@test_throws DimensionMismatch map(+, D, D, Diagonal(rand(1)))
149+
@test_throws DimensionMismatch map(+, Diagonal(rand(1)), D, D)
145150
end
146151

147152
@testset "Issue #33397" begin

0 commit comments

Comments
 (0)