Skip to content

Commit

Permalink
Better blochmessiah (#381)
Browse files Browse the repository at this point in the history
* Nicer BM

* Updates changelog

* Unnecesary import

* More simplifications

* More simplifications

* one more

* Updates docstring

---------

Co-authored-by: Nicolas Quesada <[email protected]>
  • Loading branch information
nquesada and Nicolas Quesada authored Feb 1, 2024
1 parent 6247fc8 commit 8759363
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

* Further simplifies the implementation of `decompositions.williamson` and corrects its docstring [(#380)](https://github.com/XanaduAI/thewalrus/pull/380).

* Further simplifies the implementation of `decompositions.blochmessiah` [(#381)](https://github.com/XanaduAI/thewalrus/pull/381).


### Bug fixes

### Documentation
Expand Down
87 changes: 33 additions & 54 deletions thewalrus/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
williamson
symplectic_eigenvals
blochmessiah
takagi
Code details
------------
"""
import numpy as np

from scipy.linalg import block_diag, sqrtm, schur
from scipy.linalg import sqrtm, schur, polar
from thewalrus.symplectic import sympmat
from thewalrus.quantum.gaussian_checks import is_symplectic

Expand Down Expand Up @@ -75,26 +76,24 @@ def williamson(V, rtol=1e-05, atol=1e-08):

M12 = np.real_if_close(sqrtm(V))
Mm12 = np.linalg.inv(M12)
r1 = Mm12 @ omega @ Mm12
s1, K = schur(r1)
# In what follows a permutation matrix perm1 is constructed so that the Schur matrix has
Gamma = Mm12 @ omega @ Mm12
a, O = schur(Gamma)
# In what follows a permutation matrix perm is constructed so that the Schur matrix has
# only positive elements above the diagonal
# Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus a permutation perm2 is used
# Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus the permutation perm is updated
# to go to the ordering x_1, ..., x_n, p_1, ... , p_n
perm1 = np.arange(2 * n)
perm = np.arange(2 * n)
for i in range(n):
if s1[2 * i, 2 * i + 1] <= 0:
(perm1[2 * i], perm1[2 * i + 1]) = (perm1[2 * i + 1], perm1[2 * i])
if a[2 * i, 2 * i + 1] <= 0:
(perm[2 * i], perm[2 * i + 1]) = (perm[2 * i + 1], perm[2 * i])

perm2 = np.array([perm1[2 * i] for i in range(n)] + [perm1[2 * i + 1] for i in range(n)])
perm = np.array([perm[2 * i] for i in range(n)] + [perm[2 * i + 1] for i in range(n)])

Ktt = K[:, perm2]
s1t = s1[:, perm1][perm1]

dd = np.array([1 / s1t[2 * i, 2 * i + 1] for i in range(n)])
dd = np.concatenate([dd, dd])
O = O[:, perm]
phi = np.abs(np.diag(a, k=1)[::2])
dd = np.concatenate([1 / phi, 1 / phi])
ddsqrt = 1 / np.sqrt(dd)
S = M12 @ Ktt * ddsqrt
S = M12 @ O * ddsqrt
return np.diag(dd), S


Expand All @@ -107,62 +106,42 @@ def symplectic_eigenvals(cov):
Returns:
(array): symplectic eigenvalues
"""
M = int(len(cov) / 2)
M = len(cov) // 2
Omega = sympmat(M)
return np.real_if_close(-1j * np.linalg.eigvals(Omega @ cov))[::2]


def blochmessiah(S):
"""Returns the Bloch-Messiah decomposition of a symplectic matrix S = uff @ dff @ vff
where uff and vff are orthogonal symplectic matrices and dff is a diagonal matrix
"""Returns the Bloch-Messiah decomposition of a symplectic matrix S = O @ D @ Q
where O and Q are orthogonal symplectic matrices and D is a positive-definite diagonal matrix
of the form diag(d1,d2,...,dn,d1^-1, d2^-1,...,dn^-1),
Args:
S (array[float]): 2N x 2N real symplectic matrix
Returns:
tuple(array[float], : orthogonal symplectic matrix uff
array[float], : diagonal matrix dff
array[float]) : orthogonal symplectic matrix vff
tuple(array[float], : orthogonal symplectic matrix O
array[float], : diagonal matrix D
array[float]) : orthogonal symplectic matrix Q
"""

N, _ = S.shape

if not is_symplectic(S):
raise ValueError("Input matrix is not symplectic.")

# Changing Basis
R = (1 / np.sqrt(2)) * np.block(
[[np.eye(N // 2), 1j * np.eye(N // 2)], [np.eye(N // 2), -1j * np.eye(N // 2)]]
)
Sc = R @ S @ np.conjugate(R).T
# Polar Decomposition
u1, d1, v1 = np.linalg.svd(Sc)
Sig = u1 @ np.diag(d1) @ np.conjugate(u1).T
Unitary = u1 @ v1
# Blocks of Unitary and Hermitian symplectics
alpha = Unitary[0 : N // 2, 0 : N // 2]
beta = Sig[0 : N // 2, N // 2 : N]
# Bloch-Messiah in this Basis
d2, takagibeta = takagi(beta)
sval = np.arcsinh(d2)
uf = block_diag(takagibeta, takagibeta.conj())
blc = np.conjugate(takagibeta).T @ alpha
vf = block_diag(blc, blc.conj())
df = np.block(
[
[np.diag(np.cosh(sval)), np.diag(np.sinh(sval))],
[np.diag(np.sinh(sval)), np.diag(np.cosh(sval))],
]
)
# Rotating Back to Original Basis
uff = np.conjugate(R).T @ uf @ R
vff = np.conjugate(R).T @ vf @ R
dff = np.conjugate(R).T @ df @ R
dff = np.real_if_close(dff)
vff = np.real_if_close(vff)
uff = np.real_if_close(uff)
return uff, dff, vff
N = N // 2
V, P = polar(S, side="left")
A = P[:N, :N]
B = P[:N, N:]
C = P[N:, N:]
M = A - C + 1j * (B + B.T)
Lam, W = takagi(M)
Lam = 0.5 * Lam
O = np.block([[W.real, -W.imag], [W.imag, W.real]])
Q = O.T @ V
sqrt1pLam2 = np.sqrt(1 + Lam**2)
D = np.diag(np.concatenate([sqrt1pLam2 + Lam, sqrt1pLam2 - Lam]))
return O, D, Q


def takagi(A, svd_order=True):
Expand Down

0 comments on commit 8759363

Please sign in to comment.