Source code for state_moseq.hhmm_standard

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm, dirichlet
from jaxtyping import Array, Float, Int
from jax.nn import softmax
import tqdm
from typing import Tuple
from functools import partial
from dynamax.hidden_markov_model import (
    hmm_filter,
    hmm_smoother,
    hmm_posterior_mode,
    hmm_posterior_sample,
    parallel_hmm_posterior_sample,
)

from .util import (
    simulate_markov_chain,
    lower_dim,
    raise_dim,
    count_transitions,
)

na = jnp.newaxis

"""
data = {
    "syllables": (n_sequences, n_timesteps, n_syllables),
    "mask": (n_sequences, n_timesteps),
}

states: (n_sequences, n_timesteps)

params = {
    "emissions": (n_states, n_syllables, n_syllables),
    "trans_probs": (n_states, n_states),
}

hypparams = {
    "n_states": (,),
    "emission_beta": (,),
    "trans_beta": (,),
    "trans_kappa": (,),
    "n_syllables"
}
"""


[docs] def obs_log_likelihoods( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Compute log likelihoods of observations for each hidden state.""" n_sequences = data["syllables"].shape[0] n_states = params["trans_probs"].shape[0] log_syllable_trans_probs = jnp.log(params["emissions"]) log_likelihoods = jax.vmap( lambda T: T[data["syllables"][:, :-1], data["syllables"][:, 1:]] )(log_syllable_trans_probs) log_likelihoods = jnp.concatenate( [jnp.zeros((n_states, n_sequences, 1)), log_likelihoods], axis=2 ) return log_likelihoods.transpose((1, 2, 0)) * data["mask"][:, :, na]
[docs] def log_params_prob( params: dict, hypparams: dict, ) -> Float: """Compute the log probability of the parameters based on their priors.""" n_states, n_syllables = params["emissions"].shape[:2] # prior on emission base parameters emissions_log_prob = jax.vmap(dirichlet.logpdf, in_axes=(0, None))( params["emissions"].reshape(-1, n_syllables), jnp.ones(n_syllables) * hypparams["emissions_beta"], ).sum() # prior on transition parameters trans_probs_log_prob = jax.vmap(dirichlet.logpdf)( params["trans_probs"], jnp.eye(n_states) * hypparams["trans_kappa"] + hypparams["trans_beta"], ).sum() return emissions_log_prob + trans_probs_log_prob
[docs] def log_joint_prob( data: dict, params: dict, hypparams: dict, ) -> Float: """Compute the log joint probability of the data and parameters.""" return marginal_loglik(data, params) + log_params_prob(params, hypparams)
[docs] @partial(jax.jit, static_argnums=(3,)) def resample_states( seed: Float[Array, "2"], data: dict, params: dict, parallel: bool = False, ) -> Tuple[Int[Array, "n_sequences n_timesteps"], Float]: """Resample hidden states from their posterior distribution. Args: seed: random seed data: data dictionary params: parameters dictionary parallel: whether to use parallel message passing Returns: states: resampled hidden states marginal_loglik: marginal log likelihood of the data """ n_states = params["trans_probs"].shape[0] seeds = jr.split(seed, data["syllables"].shape[0]) sample_fn = parallel_hmm_posterior_sample if parallel else hmm_posterior_sample marginal_logliks, states = jax.vmap(sample_fn, in_axes=(0, None, None, 0))( seeds, jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ) return states, marginal_logliks.sum()
[docs] def fit_gibbs( data: dict, hypparams: dict, init_params: dict, init_states: Int[Array, "n_sequences n_timesteps"] = None, seed: Float[Array, "2"] = jr.PRNGKey(0), num_iters: Int = 100, parallel: bool = False, ) -> Tuple[dict, Float[Array, "num_iters"]]: """Fit a model using Gibbs sampling. Args: data: data dictionary hypparams: hyperparameters dictionary init_params: initial parameters directionary init_states: initial hidden states (optional) seed: random seed num_iters: number of iterations parallel: whether to use parallel message passing Returns: params: fitted parameters dictionary log_joints: log joint probability of the data and parameters recorded at each iteration """ if init_states is None: states, _ = resample_states(seed, data, init_params, parallel) log_joints = [] params = init_params for _ in tqdm.trange(num_iters): seed, subseed = jr.split(seed) params = resample_params(subseed, data, params, states, hypparams) states, marginal_loglik = resample_states(seed, data, params, parallel) log_joints.append(marginal_loglik + log_params_prob(params, hypparams)) return params, states, jnp.array(log_joints)
[docs] def initialize_params( data: dict, hypparams: dict, states: Int[Array, "n_sequences n_timesteps"] = None, seed: Float[Array, "2"] = jr.PRNGKey(0), ) -> dict: """Initialize parameters by sampling from their prior distribution or using provided states. Args: data: data dictionary hypparams: hyperparameters dictionary states: states used for initializing the parameters (optional) seed: random seed """ if states is not None: params = resample_params(seed, data, states, hypparams) else: params = random_params(seed, hypparams) return params
[docs] def marginal_loglik( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginal log likelihood of the data""" n_states = params["trans_probs"].shape[0] mll = jax.vmap(hmm_filter, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).marginal_loglik.sum() return mll
[docs] def smoothed_states( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginals of hidden states using forward-backward algorithm.""" n_states = params["trans_probs"].shape[0] return jax.vmap(hmm_smoother, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).smoothed_probs
[docs] def filtered_states( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginals of hidden states using forward-backward algorithm.""" n_states = params["trans_probs"].shape[0] return jax.vmap(hmm_filter, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).filtered_probs
[docs] def predicted_states( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Predict hidden states using Viterbi algorithm.""" n_states = params["trans_probs"].shape[0] return jax.vmap(hmm_posterior_mode, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), )
[docs] def random_params( seed: Float[Array, "2"], hypparams: dict, ) -> dict: """Generate random model parameters. emissions ~ Dirichlet(emissions_beta) (for each state and syllable) trans_probs ~ Dirichlet(trans_beta + trans_kappa * I) Args: seed: random seed hypparams: hyperparameters dictionary Returns: params: parameters dictionary """ n_syllables = hypparams["n_syllables"] n_states = hypparams["n_states"] seeds = jr.split(seed, 2) emissions = jr.dirichlet( seeds[0], jnp.ones((n_states, n_syllables, n_syllables)) * hypparams["emissions_beta"], ) trans_probs = jax.vmap(jr.dirichlet)( jr.split(seeds[2], n_states), jnp.eye(n_states) * hypparams["trans_kappa"] + hypparams["trans_beta"], ) return { "emissions": emissions, "trans_probs": trans_probs, }
[docs] def simulate( seed: Float[Array, "2"], params: dict, n_timesteps: Int, n_sequences: Int, ) -> Tuple[ Int[Array, "n_sequences n_timesteps"], Int[Array, "n_sequences n_timesteps"] ]: """Simulate data from the model. Args: seed: random seed params: parameters dictionary n_timesteps: number of timesteps to simulate n_sequences: number of sessions to simulate Returns: states: simulated states syllables: simulated syllables """ seeds = jr.split(seed, 3) states = jax.vmap(simulate_markov_chain, in_axes=(0, None, None))( jr.split(seeds[0], n_sequences), params["trans_probs"], n_timesteps, ) syllable_trans_probs = params["emissions"][states] syllables = jax.vmap(simulate_markov_chain, in_axes=(0, 0, None))( jr.split(seeds[1], n_sequences), syllable_trans_probs, n_timesteps ) return states, syllables
[docs] def resample_params( seed: Float[Array, "2"], data: dict, params: dict, states: Int[Array, "n_sequences n_timesteps"], hypparams: dict, ) -> dict: """Resample parameters from their posterior distribution. Args: seed: random seed data: data dictionary params: parameters dictionary states: hidden states hypparams: hyperparameters dictionary Returns: params: parameters dictionary """ seeds = jr.split(seed, 2) emissions = resample_emission_params( seeds[1], data["syllables"], data["mask"], states, hypparams["n_states"], hypparams["n_syllables"], hypparams["emissions_beta"], ) trans_probs = resample_trans_probs( seeds[2], data["mask"], states, hypparams["n_states"], hypparams["trans_beta"], hypparams["trans_kappa"], ) params = { "emissions": emissions, "trans_probs": trans_probs, } return params emissions = resample_emission_params( seeds[1], data["syllables"], data["mask"], states, hypparams["n_states"], hypparams["n_syllables"], hypparams["emissions_beta"], )
[docs] @partial(jax.jit, static_argnums=(4, 5)) def resample_emission_params( seed: Float[Array, "2"], syllables: Int[Array, "n_sequences n_timesteps"], mask: Int[Array, "n_sequences n_timesteps"], states: Int[Array, "n_sequences n_timesteps"], n_states: int, n_syllables: int, emissions_beta: Float, ) -> Float[Array, "n_states n_syllables n_syllables"]: """Resample emission parameters from their posterior distribution. Args: seed: random seed syllables: syllable observations mask: mask of valid observations states: hidden states n_states: number of hidden states n_syllables: number of syllables emission_base_sigma: emission base standard deviation emission_biases_sigma: emission biases standard deviation Returns: emissions: syllable transition probabilities for each state """ sufficient_stats = ( jnp.zeros((n_states, n_syllables, n_syllables)) .at[states[:, 1:], syllables[:, :-1], syllables[:, 1:]] .add(mask[:, 1:]) ) emissions = jax.vmap(jr.dirichlet)( jr.split(seed, n_states * n_syllables), sufficient_stats.reshape(-1, n_syllables) + emissions_beta, ).reshape(n_states, n_syllables, n_syllables) return emissions
[docs] @partial(jax.jit, static_argnums=(3,)) def resample_trans_probs( seed: Float[Array, "2"], mask: Int[Array, "n_sequences n_timesteps"], states: Int[Array, "n_sequences n_timesteps"], n_states: int, beta: Float, kappa: Float, ) -> Float[Array, "n_states n_states"]: """Resample transition probabilities from their posterior distribution. Args: seed: random seed mask: mask of valid observations states: hidden states n_states: number of hidden states beta: Dirichlet concentration parameter kappa: Dirichlet concentration parameter Returns: trans_probs: posterior transition probabilities """ trans_counts = jax.vmap(count_transitions, in_axes=(0, 0, None))(states, mask, n_states).sum(0) trans_probs = jax.vmap(jr.dirichlet)( jr.split(seed, n_states), trans_counts + beta + jnp.eye(n_states) * kappa ) return trans_probs