Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for inhomogeneous parameters in LinearGaussianConjugateSSM.fit_blocked_gibbs #403

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 105 additions & 33 deletions dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from fastprogress.fastprogress import progress_bar
from functools import partial
from jax import jit
from jax import jit, tree, vmap
from jax.tree_util import tree_map
from jaxtyping import Array, Float
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from typing import Any, Optional, Tuple, Union, runtime_checkable
from typing_extensions import Protocol
from typing_extensions import Protocol

from dynamax.ssm import SSM
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
Expand Down Expand Up @@ -206,7 +206,7 @@ def sample(self,
key: PRNGKeyT,
num_timesteps: int,
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \
-> Tuple[Float[Array, "num_timesteps state_dim"],
-> Tuple[Float[Array, "num_timesteps state_dim"],
Float[Array, "num_timesteps emission_dim"]]:
"""Sample from the model.

Expand Down Expand Up @@ -588,6 +588,47 @@ def m_step(self,
)
return params, m_step_state

def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
    """Return `params` with missing biases and input weights filled in as zeros.

    Only the `bias` and `input_weights` fields of the dynamics and emissions
    parameters are inspected; `weights`, `cov` and the initial parameters are
    passed through untouched.

    Args:
        params: model parameters whose dynamics/emissions `bias` or
            `input_weights` may be None.
        num_timesteps: number of timesteps, used to size the leading time
            axis of the zero arrays when the model is inhomogeneous.

    Returns:
        A new ParamsLGSSM in which every None bias / input-weight entry has
        been replaced by a zero array of the matching shape.
    """
    dyn = params.dynamics
    emi = params.emissions

    # The model is treated as inhomogeneous (time-varying) when the dynamics
    # weights are rank 3. The same flag also decides the emissions layout.
    time_varying = dyn.weights.ndim == 3

    # There is one dynamics parameter set per transition (T - 1 of them) and
    # one emission parameter set per timestep (T of them).
    dyn_prefix = (num_timesteps - 1,) if time_varying else ()
    emi_prefix = (num_timesteps,) if time_varying else ()

    def _fill(value, shape):
        # Keep a provided value; otherwise substitute zeros of `shape`.
        return jnp.zeros(shape) if value is None else value

    filled_dynamics = ParamsLGSSMDynamics(
        weights=dyn.weights,
        bias=_fill(dyn.bias, shape=dyn_prefix + (self.state_dim,)),
        input_weights=_fill(
            dyn.input_weights, shape=dyn_prefix + (self.state_dim, self.input_dim)
        ),
        cov=dyn.cov,
    )
    filled_emissions = ParamsLGSSMEmissions(
        weights=emi.weights,
        bias=_fill(emi.bias, shape=emi_prefix + (self.emission_dim,)),
        input_weights=_fill(
            emi.input_weights, shape=emi_prefix + (self.emission_dim, self.input_dim)
        ),
        cov=emi.cov,
    )
    return ParamsLGSSM(
        initial=params.initial,
        dynamics=filled_dynamics,
        emissions=filled_emissions,
    )


def fit_blocked_gibbs(self,
key: PRNGKeyT,
initial_params: ParamsLGSSM,
Expand All @@ -599,7 +640,8 @@ def fit_blocked_gibbs(self,

Args:
key: random number key.
initial_params: starting parameters.
initial_params: starting parameters. Include a leading time axis for
the dynamics and emissions parameters in inhomogeneous models.
sample_size: how many samples to draw.
emissions: set of observation sequences.
inputs: optional set of input sequences.
Expand All @@ -609,67 +651,97 @@ def fit_blocked_gibbs(self,
"""
num_timesteps = len(emissions)

# Inhomogeneous models have a leading time dimension.
is_inhomogeneous = initial_params.dynamics.weights.ndim == 3

if inputs is None:
inputs = jnp.zeros((num_timesteps, 0))

initial_params = self._check_params(initial_params, num_timesteps)

def sufficient_stats_from_sample(states):
"""Convert samples of states to sufficient statistics."""
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
# Let xn[t] = x[t+1] for t = 0...T-2
x, xp, xn = states, states[:-1], states[1:]
u, up = inputs_joint, inputs_joint[:-1]
x, xn = states, states[1:]
u = inputs_joint
# Let z[t] = [x[t], u[t]] for t = 0...T-1
z = jnp.concatenate([x, u], axis=-1)
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
zp = z[:-1]
y = emissions

init_stats = (x[0], jnp.outer(x[0], x[0]), 1)

# Quantities for the dynamics distribution
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]])
sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]])
sum_xnxnT = xn.T @ xn
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
sum_zpzpT = jnp.einsum('ti,tj->tij', zp, zp)
sum_zpxnT = jnp.einsum('ti,tj->tij', zp, xn)
sum_xnxnT = jnp.einsum('ti,tj->tij', xn, xn)
z_is_observed = jnp.ones(num_timesteps - 1)
# The dynamics stats have a leading time dimension.
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, z_is_observed)
if not self.has_dynamics_bias:
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
num_timesteps - 1)
dynamics_stats = (sum_zpzpT[:, :-1, :-1], sum_zpxnT[:, :-1, :], sum_xnxnT,
z_is_observed)

# Quantities for the emissions
# Let z[t] = [x[t], u[t]] for t = 0...T-1
sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]])
sum_zyT = jnp.block([[x.T @ y], [u.T @ y]])
sum_yyT = y.T @ y
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
sum_zzT = jnp.einsum('ti,tj->tij', z, z)
sum_zyT = jnp.einsum('ti,tj->tij', z, y)
sum_yyT = jnp.einsum('ti,tj->tij', y, y)
y_is_observed = jnp.ones(num_timesteps)
# The emissions stats have a leading time dimension.
emission_stats = (sum_zzT, sum_zyT, sum_yyT, y_is_observed)
if not self.has_emissions_bias:
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
emission_stats = (sum_zzT[:, :-1, :-1], sum_zyT[:, :-1, :], sum_yyT, y_is_observed)

return init_stats, dynamics_stats, emission_stats

def lgssm_params_sample(rng, stats):
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
init_stats, dynamics_stats, emission_stats = stats
rngs = iter(jr.split(rng, 3))

# Sample the initial params
def _sample_initial_params(rng, init_stats):
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
S, m = initial_posterior.sample(seed=next(rngs))
S, m = initial_posterior.sample(seed=rng)
return ParamsLGSSMInitial(mean=m, cov=S)

# Sample the dynamics params
def _sample_dynamics_params(rng, dynamics_stats):
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
Q, FB = dynamics_posterior.sample(seed=next(rngs))
Q, FB = dynamics_posterior.sample(seed=rng)
F = FB[:, :self.state_dim]
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
return ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q)

# Sample the emission params
def _sample_emission_params(rng, emission_stats):
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
R, HD = emission_posterior.sample(seed=next(rngs))
R, HD = emission_posterior.sample(seed=rng)
H = HD[:, :self.state_dim]
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
return ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)

def lgssm_params_sample(rng, stats):
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
init_stats, dynamics_stats, emission_stats = stats
rngs = iter(jr.split(rng, 3))

# Sample the initial params
initial_params = _sample_initial_params(next(rngs), init_stats)

# Sample the dynamics and emission params.
if not is_inhomogeneous:
# Aggregate summary statistics across time for homogeneous model.
dynamics_stats = tree.map(lambda x: jnp.sum(x, axis=0), dynamics_stats)
emission_stats = tree.map(lambda x: jnp.sum(x, axis=0), emission_stats)
dynamics_params = _sample_dynamics_params(next(rngs), dynamics_stats)
emission_params = _sample_emission_params(next(rngs), emission_stats)
else:
keys_dynamics = jr.split(next(rngs), num_timesteps - 1)
keys_emission = jr.split(next(rngs), num_timesteps)
dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats)
emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats)

params = ParamsLGSSM(
initial=ParamsLGSSMInitial(mean=m, cov=S),
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
initial=initial_params,
dynamics=dynamics_params,
emissions=emission_params,
)
return params

Expand Down
55 changes: 54 additions & 1 deletion dynamax/linear_gaussian_ssm/models_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""
Tests for the linear Gaussian SSM models.
"""
from itertools import count, product

import pytest
import jax.numpy as jnp
import jax.random as jr
from jax import tree
import pytest

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM
from dynamax.utils.utils import monotonically_increasing

NUM_TIMESTEPS = 100
Expand All @@ -29,3 +33,52 @@ def test_sample_and_fit(cls, kwargs, inputs):
fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3)
assert monotonically_increasing(lps)
fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3)

@pytest.mark.parametrize(["has_dynamics_bias", "has_emissions_bias"], product([True, False], repeat=2))
def test_inhomogeneous_lgcssm(has_dynamics_bias, has_emissions_bias):
    """
    Test a LinearGaussianConjugateSSM with time-varying dynamics and emission model.

    Homogeneous parameters are tiled along a new leading time axis, the model is
    fit with blocked Gibbs sampling, and the shapes of the last sampled
    parameters are checked for that leading axis.
    """
    state_dim = 2
    emission_dim = 3
    num_timesteps = 4
    # Endless stream of distinct keys PRNGKey(0), PRNGKey(1), ... Every random
    # call below draws from this stream so that no key is ever used twice.
    keys = map(jr.PRNGKey, count())
    kwargs = {
        "state_dim": state_dim,
        "emission_dim": emission_dim,
        "has_dynamics_bias": has_dynamics_bias,
        "has_emissions_bias": has_emissions_bias,
    }
    model = LinearGaussianConjugateSSM(**kwargs)
    # The parameter properties returned by `initialize` are not needed here.
    params, _ = model.initialize(next(keys))
    # Repeat the parameters for each timestep: one dynamics parameter set per
    # transition (T - 1) and one emission parameter set per timestep (T).
    inhomogeneous_dynamics = tree.map(
        lambda x: jnp.repeat(x[None], num_timesteps - 1, axis=0), params.dynamics,
    )
    inhomogeneous_emissions = tree.map(
        lambda x: jnp.repeat(x[None], num_timesteps, axis=0), params.emissions,
    )

    # Emissions are sampled from the homogeneous parameters.
    _, emissions = model.sample(params, next(keys), num_timesteps=num_timesteps)
    inhomogeneous_params = ParamsLGSSM(
        initial=params.initial,
        dynamics=inhomogeneous_dynamics,
        emissions=inhomogeneous_emissions,
    )
    params_trace = model.fit_blocked_gibbs(
        next(keys),
        inhomogeneous_params,
        sample_size=5,
        emissions=emissions,
    )

    # Arbitrarily check the last set of parameters from the Markov chain.
    last_params = tree.map(lambda x: x[-1], params_trace)
    assert last_params.initial.mean.shape == (state_dim,)
    assert last_params.initial.cov.shape == (state_dim, state_dim)
    assert last_params.dynamics.weights.shape == (num_timesteps - 1, state_dim, state_dim)
    assert last_params.emissions.weights.shape == (num_timesteps, emission_dim, state_dim)
    assert last_params.dynamics.bias.shape == (num_timesteps - 1, state_dim)
    assert last_params.emissions.bias.shape == (num_timesteps, emission_dim)
    assert last_params.dynamics.cov.shape == (num_timesteps - 1, state_dim, state_dim)
    assert last_params.emissions.cov.shape == (num_timesteps, emission_dim, emission_dim)