Skip to content

Commit

Permalink
Fix several issues with empty matrices
Browse files Browse the repository at this point in the history
* make det of null matrix return 1, see e.g.
  https://en.wikipedia.org/wiki/Matrix_(mathematics)#Empty_matrix
* fix str/repr

Closes mpmath#745

Co-authored-by: Sergey B Kirpichev <[email protected]>
  • Loading branch information
mart-mihkel and skirpichev committed Feb 24, 2024
1 parent 529daf5 commit 49e16e0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
8 changes: 7 additions & 1 deletion mpmath/matrices/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def LU_decomp(ctx, A, overwrite=False, use_cache=True):
A[i,j] /= A[j,j]
for k in range(j + 1, n):
A[i,k] -= A[i,j]*A[j,k]
if ctx.absmin(A[n - 1,n - 1]) <= tol:
if p and ctx.absmin(A[n - 1,n - 1]) <= tol:
raise ZeroDivisionError('matrix is numerically singular')
# cache decomposition
if not overwrite and isinstance(orig, ctx.matrix):
Expand Down Expand Up @@ -543,6 +543,12 @@ def det(ctx, A):
>>> print(det(A))
1.0
The determinant of a 0 by 0 matrix is 1 as the product of no factors
is by convention the multiplicative identity.
>>> A = matrix(0, 0)
>>> print(det(A))
1
But in general a matrix can have any number as its determinant.
>>> A = matrix([[2, 6, 4],[3, 8, 6],[1, 1, 2]])
Expand Down
23 changes: 14 additions & 9 deletions mpmath/matrices/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def __init__(self, *args, **kwargs):
" If you want to truncate values to integer, use .apply(int) instead.",
DeprecationWarning)
if isinstance(args[0], (list, tuple)):
if isinstance(args[0][0], (list, tuple)):
if not args[0]:
self.__rows = 0
self.__cols = 0
elif isinstance(args[0][0], (list, tuple)):
# interpret nested list as matrix
A = args[0]
self.__rows = len(A)
Expand Down Expand Up @@ -358,12 +361,12 @@ def __nstr__(self, n=None, **kwargs):
# Pad each element up to maxlen so the columns line up
row[j] = elem.rjust(maxlen[j])
res[i] = "[" + colsep.join(row) + "]"
return rowsep.join(res)
return rowsep.join(res) if self.rows or self.cols else ''

def __str__(self):
return self.__nstr__()

def _toliststr(self, avoid_type=False):
def _toliststr(self):
"""
Create a list string from a matrix.
Expand All @@ -375,14 +378,16 @@ def _toliststr(self, avoid_type=False):
for i in range(self.__rows):
s += '['
for j in range(self.__cols):
if not avoid_type or not isinstance(self[i,j], typ):
if not isinstance(self[i,j], typ):
a = repr(self[i,j])
else:
a = "'" + str(self[i,j]) + "'"
s += a + ', '
s = s[:-2]
if s[-1] != '[':
s = s[:-2]
s += '],\n '
s = s[:-3]
if s[-1] != '[':
s = s[:-3]
s += ']'
return s

Expand All @@ -396,7 +401,7 @@ def __repr__(self):
if self.ctx.pretty:
return self.__str__()
s = 'matrix(\n'
s += self._toliststr(avoid_type=True) + ')'
s += self._toliststr() + ')'
return s

def __get_element(self, key):
Expand Down Expand Up @@ -989,8 +994,8 @@ def mnorm(ctx, A, p=1):
p = ctx.convert(p)
m, n = A.rows, A.cols
if p == 1:
return max(ctx.fsum((A[i,j] for i in range(m)), absolute=1) for j in range(n))
return max((ctx.fsum((A[i,j] for i in range(m)), absolute=1) for j in range(n)), default=0)
elif p == ctx.inf:
return max(ctx.fsum((A[i,j] for j in range(n)), absolute=1) for i in range(m))
return max((ctx.fsum((A[i,j] for j in range(n)), absolute=1) for i in range(m)), default=0)
else:
raise NotImplementedError("matrix p-norm for arbitrary p")
3 changes: 3 additions & 0 deletions mpmath/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
[3, 8, 6],
[1, 1, 2]])

A14 = matrix(0, 0)

def test_LU_decomp():
A = A3.copy()
b = b3
Expand Down Expand Up @@ -213,6 +215,7 @@ def test_det():
assert det(zeros(3)) == 0
assert det(A11) == 0
assert absmin(det(A12*1e-30) - 1e-30) < eps
assert det(A14) == 1

def test_cond():
A = matrix([[1.2969, 0.8648], [0.2161, 0.1441]])
Expand Down
34 changes: 34 additions & 0 deletions mpmath/tests/test_str.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from mpmath import inf, matrix, nstr


A1 = matrix([])
A2 = matrix([[]])
A3 = matrix(2)
A4 = matrix([1, 2, 3])


def test_nstr():
m = matrix([[0.75, 0.190940654, -0.0299195971],
[0.190940654, 0.65625, 0.205663228],
Expand All @@ -13,3 +19,31 @@ def test_nstr():
'''[ 0.75 0.1909 -0.02992]
[ 0.1909 0.6563 0.2057]
[-0.02992 0.2057 6.445e-21]'''

def test_matrix_repr():
assert repr(A1) == \
'''matrix(
[])'''
assert repr(A2) == \
'''matrix(
[[]])'''
assert repr(A3) == \
'''matrix(
[['0.0', '0.0'],
['0.0', '0.0']])'''
assert repr(A4) == \
'''matrix(
[['1.0'],
['2.0'],
['3.0']])'''

def test_matrix_str():
assert str(A1) == ''
assert str(A2) == '[]'
assert str(A3) == \
'''[0.0 0.0]
[0.0 0.0]'''
assert str(A4) == \
'''[1.0]
[2.0]
[3.0]'''

0 comments on commit 49e16e0

Please sign in to comment.