Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting code + correct shape annotations #15

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
args: [--prose-wrap=always]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.6"
rev: "v0.1.7"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
Expand Down Expand Up @@ -84,7 +84,7 @@ repos:
- id: validate-pyproject

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.27.2
rev: 0.27.3
hooks:
- id: check-dependabot
- id: check-github-workflows
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ disallow_incomplete_defs = false
[[tool.mypy.overrides]]
module = [
"astropy.*",
"beartype.*",
"diffrax.*",
"equinox.*",
"hypothesis.*",
Expand Down
20 changes: 18 additions & 2 deletions src/galdynamix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Copyright (c) 2023 galdynamix maintainers. All rights reserved."""
# ruff:noqa: F401

__all__ = ["__version__"]
__all__ = [
"__version__",
# modules
"units",
"potential",
"integrator",
"dynamics",
"utils",
"typing",
]

import os

Expand All @@ -11,5 +21,11 @@

config.update("jax_enable_x64", True) # noqa: FBT003

typechecker: str | None
if os.environ.get("GALDYNAMIX_ENABLE_RUNTIME_TYPECHECKS", "1") == "1":
install_import_hook(["galdynamix"], "beartype.beartype")
typechecker = "beartype.beartype"
else:
typechecker = None

with install_import_hook("galdynamix", typechecker):
from galdynamix import dynamics, integrate, potential, typing, units, utils
193 changes: 108 additions & 85 deletions src/galdynamix/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,74 @@

__all__ = ["AbstractPhaseSpacePosition", "PhaseSpacePosition"]

from typing import Any, cast
from abc import abstractmethod
from typing import TYPE_CHECKING, Any

import equinox as eqx
import jax.numpy as xp
from jaxtyping import Array, Float

from galdynamix.typing import VectorN, VectorN3, VectorN6, VectorN7
from galdynamix.utils._jax import partial_jit
from galdynamix.utils.dataclasses import field
from galdynamix.typing import BatchFloatScalar, BatchVec3, BatchVec6, BatchVec7
from galdynamix.utils import partial_jit
from galdynamix.utils._shape import atleast_batched, batched_shape

if TYPE_CHECKING:
from galdynamix.potential._potential.base import AbstractPotentialBase

def convert_to_N3(x: Any) -> VectorN3:
"""Convert to a 3-vector."""

def converter_batchvec(x: Any) -> Float[Array, "*batch _"]:
"""Convert to a batched vector."""
out = xp.asarray(x)
if out.ndim == 1:
out = out[None, :]
if out.ndim == 0:
out = out[None]
return out # shape checking done by jaxtyping + beartype


class AbstractPhaseSpacePosition(eqx.Module): # type: ignore[misc]
class AbstractPhaseSpacePositionBase(eqx.Module): # type: ignore[misc]
"""Abstract Base Class of Phase-Space Positions.

Todo:
----
- Units stuff
- GR stuff
- GR stuff (note that then this will include time and can be merged with
``AbstractPhaseSpacePosition``)
"""

q: VectorN3 = field(converter=convert_to_N3)
"""Position of the stream particles (x, y, z) [kpc]."""

p: VectorN3 = field(converter=convert_to_N3)
"""Position of the stream particles (x, y, z) [kpc/Myr]."""
q: BatchVec3 = eqx.field(converter=converter_batchvec)
"""Positions (x, y, z)."""

t: VectorN
"""Array of times [Myr]."""
p: BatchVec3 = eqx.field(converter=converter_batchvec)
r"""Conjugate momenta (v_x, v_y, v_z)."""

@property
@partial_jit()
def qp(self) -> VectorN6:
"""Return as a single Array[(N, Q + P),]."""
# Determine output shape
qd = self.q.shape[1] # dimensionality of q
shape = (self.q.shape[0], qd + self.p.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qd].set(self.q)
out = out.at[:, qd:].set(self.p)
return out # noqa: RET504
@abstractmethod
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch, component shape."""
raise NotImplementedError

@property
@partial_jit()
def w(self) -> VectorN7:
"""Return as a single Array[(N, Q + P + T),]."""
qp = self.qp
qpd = qp.shape[1] # dimensionality of qp
# Reshape t to (N, 1) if necessary
t = self.t[:, None] if self.t.ndim == 1 else self.t
# Determine output shape
shape = (qp.shape[0], qpd + t.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qpd].set(qp)
out = out.at[:, qpd:].set(t)
return out # noqa: RET504
def shape(self) -> tuple[int, ...]:
"""Shape of the position and velocity arrays."""
batch_shape, component_shapes = self._shape_tuple
return (*batch_shape, sum(component_shapes))

# ==========================================================================
# Array stuff
# Convenience properties

@property
def shape(self) -> tuple[int, ...]:
"""Shape of the position and velocity arrays."""
return cast(
tuple[int, ...],
xp.broadcast_shapes(self.q.shape, self.p.shape, self.t.shape),
)
@partial_jit()
def qp(self) -> BatchVec6:
"""Return as a single Array[(*batch, Q + P),]."""
batch_shape, component_shapes = self._shape_tuple
q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1])
p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2])
return xp.concatenate((q, p), axis=-1)

# ==========================================================================
# Dynamical quantities

@partial_jit()
def kinetic_energy(self) -> VectorN:
def kinetic_energy(self) -> BatchFloatScalar:
r"""Return the specific kinetic energy.

.. math::
Expand All @@ -91,51 +78,45 @@ def kinetic_energy(self) -> VectorN:

Returns
-------
E : :class:`~astropy.units.Quantity`
E : Array[float, (*batch,)]
The kinetic energy.
"""
# TODO: use a ``norm`` function
# TODO: use a ``norm`` function so that this works for non-Cartesian.
return 0.5 * xp.sum(self.p**2, axis=-1)

@partial_jit() # TODO: annotate as AbstractPotentialBase
def potential_energy(self, potential: Any, /) -> VectorN:
r"""Return the specific potential energy.

.. math::
class AbstractPhaseSpacePosition(AbstractPhaseSpacePositionBase):
"""Abstract Base Class of Phase-Space Positions."""

E_\Phi = \Phi(\boldsymbol{q})
t: BatchFloatScalar = eqx.field(converter=converter_batchvec)
"""Array of times."""

Parameters
----------
potential : `galdynamix.potential.AbstractPotentialBase`
The potential object to compute the energy from.

Returns
-------
E : :class:`~jax.Array`
The specific potential energy.
"""
return potential.potential_energy(self, self.t)

@partial_jit() # TODO: annotate as AbstractPotentialBase
def energy(self, potential: Any, /) -> VectorN:
r"""Return the specific total energy.

.. math::
@property
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch ."""
qbatch, qshape = batched_shape(self.q, expect_scalar=False)
pbatch, pshape = batched_shape(self.p, expect_scalar=False)
tbatch, tshape = batched_shape(self.t, expect_scalar=True)
batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, (qshape, pshape, tshape)

E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
E_\Phi = \Phi(\boldsymbol{q})
E = E_K + E_\Phi
# ==========================================================================
# Convenience properties

Returns
-------
E : :class:`~astropy.units.Quantity`
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)
@property
@partial_jit()
def w(self) -> BatchVec7:
"""Return as a single Array[(*batch, Q + P + T),]."""
batch_shape, component_shapes = self._shape_tuple
q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1])
p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2])
t = xp.broadcast_to(
atleast_batched(self.t), batch_shape + component_shapes[2:3]
)
return xp.concatenate((q, p, t), axis=-1)

@partial_jit()
def angular_momentum(self) -> VectorN3:
def angular_momentum(self) -> BatchVec3:
r"""Compute the angular momentum.

.. math::
Expand All @@ -147,7 +128,7 @@ def angular_momentum(self) -> VectorN3:

Returns
-------
L : :class:`~astropy.units.Quantity`
L : Array[float, (*batch,3)]
Array of angular momentum vectors.

Examples
Expand All @@ -163,6 +144,48 @@ def angular_momentum(self) -> VectorN3:
# TODO: when q, p are not Cartesian.
return xp.cross(self.q, self.p)

# ==========================================================================
# Dynamical quantities

@partial_jit()
def potential_energy(
self, potential: "AbstractPotentialBase", /
) -> BatchFloatScalar:
r"""Return the specific potential energy.

.. math::

E_\Phi = \Phi(\boldsymbol{q})

Parameters
----------
potential : `galdynamix.potential.AbstractPotentialBase`
The potential object to compute the energy from.

Returns
-------
E : Array[float, (*batch,)]
The specific potential energy.
"""
return potential.potential_energy(self.q, t=self.t)

@partial_jit()
def energy(self, potential: "AbstractPotentialBase", /) -> BatchFloatScalar:
r"""Return the specific total energy.

.. math::

E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
E_\Phi = \Phi(\boldsymbol{q})
E = E_K + E_\Phi

Returns
-------
E : Array[float, (*batch,)]
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)


class PhaseSpacePosition(AbstractPhaseSpacePosition):
pass
10 changes: 5 additions & 5 deletions src/galdynamix/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import override

from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.typing import FloatScalar
from galdynamix.typing import BatchFloatScalar
from galdynamix.utils._jax import partial_jit

from ._core import AbstractPhaseSpacePosition
Expand All @@ -29,7 +29,7 @@ class Orbit(AbstractPhaseSpacePosition):
@partial_jit()
def potential_energy(
self, potential: AbstractPotentialBase | None = None, /
) -> FloatScalar:
) -> BatchFloatScalar:
r"""Return the specific potential energy.

.. math::
Expand All @@ -43,9 +43,9 @@ def potential_energy(

Returns
-------
E : :class:`~jax.Array`
E : Array[float, (*batch,)]
The specific potential energy.
"""
if potential is None:
return self.potential.potential_energy(self, self.t)
return potential.potential_energy(self, self.t)
return self.potential.potential_energy(self.q, t=self.t)
return potential.potential_energy(self.q, t=self.t)
Loading