Skip to content

Commit d07a863

Browse files
dkarraschKristofferC
authored and
KristofferC
committed
Accomodate for rectangular matrices in copytrito! (#54587)
(cherry picked from commit fc54be6)
1 parent 0653044 commit d07a863

File tree

4 files changed

+101
-20
lines changed

4 files changed

+101
-20
lines changed

stdlib/LinearAlgebra/src/generic.jl

+15-10
Original file line numberDiff line numberDiff line change
@@ -1934,19 +1934,24 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
19341934
BLAS.chkuplo(uplo)
19351935
m,n = size(A)
19361936
m1,n1 = size(B)
1937-
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
19381937
A = Base.unalias(B, A)
19391938
if uplo == 'U'
1940-
for j=1:n
1941-
for i=1:min(j,m)
1942-
@inbounds B[i,j] = A[i,j]
1943-
end
1939+
if n < m
1940+
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
1941+
else
1942+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
19441943
end
1945-
else # uplo == 'L'
1946-
for j=1:n
1947-
for i=j:m
1948-
@inbounds B[i,j] = A[i,j]
1949-
end
1944+
for j in 1:n, i in 1:min(j,m)
1945+
@inbounds B[i,j] = A[i,j]
1946+
end
1947+
else # uplo == 'L'
1948+
if m < n
1949+
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
1950+
else
1951+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
1952+
end
1953+
for j in 1:n, i in j:m
1954+
@inbounds B[i,j] = A[i,j]
19501955
end
19511956
end
19521957
return B

stdlib/LinearAlgebra/src/lapack.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -7163,9 +7163,23 @@ for (fn, elty) in ((:dlacpy_, :Float64),
71637163
function lacpy!(B::AbstractMatrix{$elty}, A::AbstractMatrix{$elty}, uplo::AbstractChar)
71647164
require_one_based_indexing(A, B)
71657165
chkstride1(A, B)
7166-
m,n = size(A)
7167-
m1,n1 = size(B)
7168-
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
7166+
m, n = size(A)
7167+
m1, n1 = size(B)
7168+
if uplo == 'U'
7169+
if n < m
7170+
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
7171+
else
7172+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7173+
end
7174+
elseif uplo == 'L'
7175+
if m < n
7176+
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
7177+
else
7178+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7179+
end
7180+
else
7181+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7182+
end
71697183
lda = max(1, stride(A, 2))
71707184
ldb = max(1, stride(B, 2))
71717185
ccall((@blasfunc($fn), libblastrampoline), Cvoid,

stdlib/LinearAlgebra/test/generic.jl

+50-6
Original file line numberDiff line numberDiff line change
@@ -647,12 +647,56 @@ end
647647

648648
@testset "copytrito!" begin
649649
n = 10
650-
A = rand(n, n)
651-
for uplo in ('L', 'U')
652-
B = zeros(n, n)
653-
copytrito!(B, A, uplo)
654-
C = uplo == 'L' ? tril(A) : triu(A)
655-
@test B C
650+
@testset "square" begin
651+
for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U')
652+
for AA in (A, view(A, reverse.(axes(A))...))
653+
C = uplo == 'L' ? tril(AA) : triu(AA)
654+
for B in (zeros(n, n), zeros(n+1, n+2))
655+
copytrito!(B, AA, uplo)
656+
@test view(B, 1:n, 1:n) == C
657+
end
658+
end
659+
end
660+
end
661+
@testset "wide" begin
662+
for A in (rand(n, 2n), rand(Int8, n, 2n))
663+
for AA in (A, view(A, reverse.(axes(A))...))
664+
C = tril(AA)
665+
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
666+
B = zeros(M, N)
667+
copytrito!(B, AA, 'L')
668+
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
669+
end
670+
@test_throws DimensionMismatch copytrito!(zeros(n-1, 2n), AA, 'L')
671+
C = triu(AA)
672+
for (M, N) in ((n, 2n), (n+1, 2n), (n, 2n+1), (n+1, 2n+1))
673+
B = zeros(M, N)
674+
copytrito!(B, AA, 'U')
675+
@test view(B, 1:n, 1:2n) == view(C, 1:n, 1:2n)
676+
end
677+
@test_throws DimensionMismatch copytrito!(zeros(n+1, 2n-1), AA, 'U')
678+
end
679+
end
680+
end
681+
@testset "tall" begin
682+
for A in (rand(2n, n), rand(Int8, 2n, n))
683+
for AA in (A, view(A, reverse.(axes(A))...))
684+
C = triu(AA)
685+
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
686+
B = zeros(M, N)
687+
copytrito!(B, AA, 'U')
688+
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
689+
end
690+
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'U')
691+
C = tril(AA)
692+
for (M, N) in ((2n, n), (2n, n+1), (2n+1, n), (2n+1, n+1))
693+
B = zeros(M, N)
694+
copytrito!(B, AA, 'L')
695+
@test view(B, 1:2n, 1:n) == view(C, 1:2n, 1:n)
696+
end
697+
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'L')
698+
end
699+
end
656700
end
657701
@testset "aliasing" begin
658702
M = Matrix(reshape(1:36, 6, 6))

stdlib/LinearAlgebra/test/lapack.jl

+19-1
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,26 @@ end
805805
B = zeros(elty, n, n)
806806
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
807807
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
808-
@test B C
808+
@test B == C
809+
B = zeros(elty, n+1, n+1)
810+
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
811+
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
812+
@test view(B, 1:n, 1:n) == C
809813
end
814+
A = rand(elty, n, n+1)
815+
B = zeros(elty, n, n)
816+
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
817+
@test B == view(tril(A), 1:n, 1:n)
818+
B = zeros(elty, n, n+1)
819+
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
820+
@test B == triu(A)
821+
A = rand(elty, n+1, n)
822+
B = zeros(elty, n, n)
823+
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
824+
@test B == view(triu(A), 1:n, 1:n)
825+
B = zeros(elty, n+1, n)
826+
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
827+
@test B == tril(A)
810828
end
811829
end
812830

0 commit comments

Comments
 (0)