From 5202215504c8bfd12f0d14527354be7f3b1e745f Mon Sep 17 00:00:00 2001 From: Hylke Donker Date: Sat, 1 Mar 2025 15:46:58 +0100 Subject: [PATCH] Add support for inhomogeneous parameters --- dynamax/linear_gaussian_ssm/models.py | 138 ++++++++++++++++----- dynamax/linear_gaussian_ssm/models_test.py | 55 +++++++- 2 files changed, 159 insertions(+), 34 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 3e8b595b..a62f9c09 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -7,12 +7,12 @@ from fastprogress.fastprogress import progress_bar from functools import partial -from jax import jit +from jax import jit, tree, vmap from jax.tree_util import tree_map from jaxtyping import Array, Float from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN from typing import Any, Optional, Tuple, Union, runtime_checkable -from typing_extensions import Protocol +from typing_extensions import Protocol from dynamax.ssm import SSM from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample @@ -206,7 +206,7 @@ def sample(self, key: PRNGKeyT, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \ - -> Tuple[Float[Array, "num_timesteps state_dim"], + -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: """Sample from the model. @@ -588,6 +588,47 @@ def m_step(self, ) return params, m_step_state + def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM: + """Replace None parameters with zeros.""" + dynamics, emissions = params.dynamics, params.emissions + is_inhomogeneous = dynamics.weights.ndim == 3 + + def _zeros_if_none(x, shape): + if x is None: + return jnp.zeros(shape) + return x + + shape_prefix = () + if is_inhomogeneous: + shape_prefix = (num_timesteps - 1,) + + clean_dynamics = ParamsLGSSMDynamics( + weights=dynamics.weights, + bias=_zeros_if_none(dynamics.bias, shape=shape_prefix + (self.state_dim,)), + input_weights=_zeros_if_none( + dynamics.input_weights, shape=shape_prefix + (self.state_dim, self.input_dim) + ), + cov=dynamics.cov + ) + shape_prefix = () + if is_inhomogeneous: + shape_prefix = (num_timesteps,) + + clean_emissions = ParamsLGSSMEmissions( + weights=emissions.weights, + bias=_zeros_if_none(emissions.bias, shape=shape_prefix + (self.emission_dim,)), + input_weights=_zeros_if_none( + emissions.input_weights, shape=shape_prefix + (self.emission_dim, self.input_dim) + ), + cov=emissions.cov + ) + return ParamsLGSSM( + initial=params.initial, + dynamics=clean_dynamics, + emissions=clean_emissions, + ) + + def fit_blocked_gibbs(self, key: PRNGKeyT, initial_params: ParamsLGSSM, @@ -599,7 +640,8 @@ def fit_blocked_gibbs(self, Args: key: random number key. - initial_params: starting parameters. + initial_params: starting parameters. Include a leading time axis for + the dynamics and emissions parameters in inhomogeneous models. sample_size: how many samples to draw. emissions: set of observation sequences. inputs: optional set of input sequences. @@ -609,67 +651,97 @@ def fit_blocked_gibbs(self, """ num_timesteps = len(emissions) + # Inhomogeneous models have a leading time dimension. + is_inhomogeneous = initial_params.dynamics.weights.ndim == 3 + if inputs is None: inputs = jnp.zeros((num_timesteps, 0)) + initial_params = self._check_params(initial_params, num_timesteps) + def sufficient_stats_from_sample(states): """Convert samples of states to sufficient statistics.""" inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1) # Let xn[t] = x[t+1] for t = 0...T-2 - x, xp, xn = states, states[:-1], states[1:] - u, up = inputs_joint, inputs_joint[:-1] + x, xn = states, states[1:] + u = inputs_joint + # Let z[t] = [x[t], u[t]] for t = 0...T-1 + z = jnp.concatenate([x, u], axis=-1) + # Let zp[t] = [x[t], u[t]] for t = 0...T-2 + zp = z[:-1] y = emissions init_stats = (x[0], jnp.outer(x[0], x[0]), 1) # Quantities for the dynamics distribution - # Let zp[t] = [x[t], u[t]] for t = 0...T-2 - sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]]) - sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]]) - sum_xnxnT = xn.T @ xn - dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) + sum_zpzpT = jnp.einsum('ti,tj->tij', zp, zp) + sum_zpxnT = jnp.einsum('ti,tj->tij', zp, xn) + sum_xnxnT = jnp.einsum('ti,tj->tij', xn, xn) + z_is_observed = jnp.ones(num_timesteps - 1) + # The dynamics stats have a leading time dimension. + dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, z_is_observed) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:, :-1, :-1], sum_zpxnT[:, :-1, :], sum_xnxnT, + z_is_observed) # Quantities for the emissions - # Let z[t] = [x[t], u[t]] for t = 0...T-1 - sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]]) - sum_zyT = jnp.block([[x.T @ y], [u.T @ y]]) - sum_yyT = y.T @ y - emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps) + sum_zzT = jnp.einsum('ti,tj->tij', z, z) + sum_zyT = jnp.einsum('ti,tj->tij', z, y) + sum_yyT = jnp.einsum('ti,tj->tij', y, y) + y_is_observed = jnp.ones(num_timesteps) + # The emissions stats have a leading time dimension. + emission_stats = (sum_zzT, sum_zyT, sum_yyT, y_is_observed) if not self.has_emissions_bias: - emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps) + emission_stats = (sum_zzT[:, :-1, :-1], sum_zyT[:, :-1, :], sum_yyT, y_is_observed) return init_stats, dynamics_stats, emission_stats - def lgssm_params_sample(rng, stats): - """Sample parameters of the model given sufficient statistics from observed states and emissions.""" - init_stats, dynamics_stats, emission_stats = stats - rngs = iter(jr.split(rng, 3)) - - # Sample the initial params + def _sample_initial_params(rng, init_stats): initial_posterior = niw_posterior_update(self.initial_prior, init_stats) - S, m = initial_posterior.sample(seed=next(rngs)) + S, m = initial_posterior.sample(seed=rng) + return ParamsLGSSMInitial(mean=m, cov=S) - # Sample the dynamics params + def _sample_dynamics_params(rng, dynamics_stats): dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) - Q, FB = dynamics_posterior.sample(seed=next(rngs)) + Q, FB = dynamics_posterior.sample(seed=rng) F = FB[:, :self.state_dim] B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + return ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q) - # Sample the emission params + def _sample_emission_params(rng, emission_stats): emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) - R, HD = emission_posterior.sample(seed=next(rngs)) + R, HD = emission_posterior.sample(seed=rng) H = HD[:, :self.state_dim] D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + return ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + + def lgssm_params_sample(rng, stats): + """Sample parameters of the model given sufficient statistics from observed states and emissions.""" + init_stats, dynamics_stats, emission_stats = stats + rngs = iter(jr.split(rng, 3)) + + # Sample the initial params + initial_params = _sample_initial_params(next(rngs), init_stats) + + # Sample the dynamics and emission params. + if not is_inhomogeneous: + # Aggregate summary statistics across time for homogeneous model. + dynamics_stats = tree.map(lambda x: jnp.sum(x, axis=0), dynamics_stats) + emission_stats = tree.map(lambda x: jnp.sum(x, axis=0), emission_stats) + dynamics_params = _sample_dynamics_params(next(rngs), dynamics_stats) + emission_params = _sample_emission_params(next(rngs), emission_stats) + else: + keys_dynamics = jr.split(next(rngs), num_timesteps - 1) + keys_emission = jr.split(next(rngs), num_timesteps) + dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats) + emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats) params = ParamsLGSSM( - initial=ParamsLGSSMInitial(mean=m, cov=S), - dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + initial=initial_params, + dynamics=dynamics_params, + emissions=emission_params, ) return params diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index 50b5aff8..2c4069c9 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -1,12 +1,16 @@ """ Tests for the linear Gaussian SSM models. """ +from itertools import count, product -import pytest +import jax.numpy as jnp import jax.random as jr +from jax import tree +import pytest from dynamax.linear_gaussian_ssm import LinearGaussianSSM from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM +from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM from dynamax.utils.utils import monotonically_increasing NUM_TIMESTEPS = 100 @@ -29,3 +33,52 @@ def test_sample_and_fit(cls, kwargs, inputs): fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3) assert monotonically_increasing(lps) fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3) + +@pytest.mark.parametrize(["has_dynamics_bias", "has_emissions_bias"], product([True, False], repeat=2)) +def test_inhomogeneous_lgcssm(has_dynamics_bias, has_emissions_bias): + """ + Test a LinearGaussianConjugateSSM with time-varying dynamics and emission model. + """ + state_dim = 2 + emission_dim = 3 + num_timesteps = 4 + keys = map(jr.PRNGKey, count()) + kwargs = { + "state_dim": state_dim, + "emission_dim": emission_dim, + "has_dynamics_bias": has_dynamics_bias, + "has_emissions_bias": has_emissions_bias, + } + model = LinearGaussianConjugateSSM(**kwargs) + params, param_props = model.initialize(jr.PRNGKey(0)) + # Repeat the parameters for each timestep. + inhomogeneous_dynamics = tree.map( + lambda x: jnp.repeat(x[None], num_timesteps - 1, axis=0), params.dynamics, + ) + inhomogeneous_emissions = tree.map( + lambda x: jnp.repeat(x[None], num_timesteps, axis=0), params.emissions, + ) + + _, emissions = model.sample(params, next(keys), num_timesteps=num_timesteps) + inhomogeneous_params = ParamsLGSSM( + initial=params.initial, + dynamics=inhomogeneous_dynamics, + emissions=inhomogeneous_emissions, + ) + params_trace = model.fit_blocked_gibbs( + next(keys), + inhomogeneous_params, + sample_size=5, + emissions=emissions, + ) + + # Arbitrarily check the last set of parameters from the Markov chain. + last_params = tree.map(lambda x: x[-1], params_trace) + assert last_params.initial.mean.shape == (state_dim,) + assert last_params.initial.cov.shape == (state_dim, state_dim) + assert last_params.dynamics.weights.shape == (num_timesteps - 1, state_dim, state_dim) + assert last_params.emissions.weights.shape == (num_timesteps, emission_dim, state_dim) + assert last_params.dynamics.bias.shape == (num_timesteps - 1, state_dim) + assert last_params.emissions.bias.shape == (num_timesteps, emission_dim) + assert last_params.dynamics.cov.shape == (num_timesteps - 1, state_dim, state_dim) + assert last_params.emissions.cov.shape == (num_timesteps, emission_dim, emission_dim)