Source code for state_moseq.hhmm_efficient

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm, dirichlet
from jaxtyping import Array, Float, Int
from jax.nn import softmax
import tqdm
from typing import Tuple
from functools import partial
from dynamax.hidden_markov_model import (
    hmm_filter,
    hmm_smoother,
    hmm_posterior_mode,
    hmm_posterior_sample,
    parallel_hmm_filter,
    parallel_hmm_smoother,
    parallel_hmm_posterior_sample,
)

from .util import (
    simulate_markov_chain,
    lower_dim,
    raise_dim,
    sample_laplace,
    count_transitions
)

na = jnp.newaxis


[docs] def estimate_emission_params( sufficient_stats: Float[Array, "n_states n_syllables n_syllables"], ) -> Tuple[ Float[Array, "n_syllables n_syllables-1"], Float[Array, "n_states-1 n_syllables-1"] ]: """Estimate emission parameters from transition counts.""" logits = jnp.log(sufficient_stats + 1e-2) emission_base_est = lower_dim(logits.mean(0), 1) emission_biases_est = (logits - logits.mean(0)[na]).mean(1) emission_biases_est = lower_dim(lower_dim(emission_biases_est, 1), 0) return emission_base_est, emission_biases_est
[docs] def get_syllable_trans_probs( emission_base: Float[Array, "n_syllables n_syllables-1"], emission_biases: Float[Array, "n_states-1 n_syllables-1"], ) -> Float[Array, "n_states n_syllables n_syllables"]: """Compute transition probabilities between syllables.""" emission_base = raise_dim(emission_base, 1) emission_biases = raise_dim(raise_dim(emission_biases, 0), 1) logits = emission_base[na] + emission_biases[:, na] return softmax(logits, axis=-1)
[docs] def obs_log_likelihoods( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Compute log likelihoods of observations for each hidden state.""" n_sequences = data["syllables"].shape[0] n_states = params["trans_probs"].shape[0] log_syllable_trans_probs = jnp.log( get_syllable_trans_probs( params["emission_base"], params["emission_biases"], ) ) log_likelihoods = jax.vmap( lambda T: T[data["syllables"][:, :-1], data["syllables"][:, 1:]] )(log_syllable_trans_probs) log_likelihoods = jnp.concatenate( [jnp.zeros((n_states, n_sequences, 1)), log_likelihoods], axis=2 ) return log_likelihoods.transpose((1, 2, 0)) * data["mask"][:, :, na]
[docs] def log_params_prob( params: dict, hypparams: dict, ) -> Float: """Compute the log probability of the parameters based on their priors.""" n_states = params["trans_probs"].shape[0] # prior on emission base parameters emission_base_log_prob = norm.logpdf( params["emission_base"] / hypparams["emission_base_sigma"] ).sum() # prior on emission bias parameters emission_biases_log_prob = norm.logpdf( params["emission_biases"] / hypparams["emission_biases_sigma"] ).sum() # prior on transition parameters trans_probs_log_prob = jax.vmap(dirichlet.logpdf)( params["trans_probs"], jnp.eye(n_states) * hypparams["trans_kappa"] + hypparams["trans_beta"], ).sum() return emission_base_log_prob + emission_biases_log_prob + trans_probs_log_prob
[docs] def log_joint_prob( data: dict, params: dict, hypparams: dict, ) -> Float: """Compute the log joint probability of the data and parameters.""" return marginal_loglik(data, params) + log_params_prob(params, hypparams)
[docs] @partial(jax.jit, static_argnums=(3,)) def resample_states( seed: Float[Array, "2"], data: dict, params: dict, parallel: bool = False, ) -> Tuple[Int[Array, "n_sequences n_timesteps"], Float]: """Resample hidden states from their posterior distribution. Args: seed: random seed data: data dictionary params: parameters dictionary parallel: whether to use parallel message passing Returns: states: resampled hidden states marginal_loglik: marginal log likelihood of the data """ n_states = params["trans_probs"].shape[0] seeds = jr.split(seed, data["syllables"].shape[0]) sample_fn = parallel_hmm_posterior_sample if parallel else hmm_posterior_sample marginal_logliks, states = jax.vmap(sample_fn, in_axes=(0, None, None, 0))( seeds, jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ) return states, marginal_logliks.sum()
[docs] def fit_gibbs( data: dict, hypparams: dict, init_params: dict, init_states: Int[Array, "n_sequences n_timesteps"] = None, seed: Float[Array, "2"] = jr.PRNGKey(0), num_iters: Int = 100, parallel: bool = False, ) -> Tuple[dict, Float[Array, "num_iters"]]: """Fit a model using Gibbs sampling. Args: data: data dictionary hypparams: hyperparameters dictionary init_params: initial parameters directionary init_states: initial hidden states (optional) seed: random seed num_iters: number of iterations parallel: whether to use parallel message passing Returns: params: fitted parameters dictionary log_joints: log joint probability of the data and parameters recorded at each iteration """ if init_states is None: states, _ = resample_states(seed, data, init_params, parallel) log_joints = [] params = init_params for _ in tqdm.trange(num_iters): seed, subseed = jr.split(seed) params, gd_losses = resample_params(subseed, data, params, states, hypparams) states, marginal_loglik = resample_states(seed, data, params, parallel) log_joints.append(marginal_loglik + log_params_prob(params, hypparams)) return params, states, jnp.array(log_joints)
[docs] def initialize_params( data: dict, hypparams: dict, states: Int[Array, "n_sequences n_timesteps"] = None, seed: Float[Array, "2"] = jr.PRNGKey(0), ) -> dict: """Initialize parameters by sampling from their prior distribution or using provided states. Args: data: data dictionary hypparams: hyperparameters dictionary states: states used for initializing the parameters (optional) seed: random seed """ if states is not None: params = resample_params(seed, data, states, hypparams) else: params = random_params(seed, hypparams) return params
[docs] def fit_gradient_descent( data: dict, hypparams: dict, init_params: dict, num_iters: Int = 100, learning_rate: Float = 1e-3, ) -> Tuple[dict, Float[Array, "num_iters"]]: """Fit a model using gradient descent. Args: data: data dictionary hypparams: hyperparameters dictionary init_params: initial parameters directionary num_iters: number of iterations learning_rate: learning rate for gradient descent Returns: params: fitted parameters dictionary log_joints: log joint probability of the data and parameters recorded at each iteration """ loss_fn = lambda params: -log_joint_prob(data, params, hypparams) params, losses = gradient_descent(loss_fn, init_params, learning_rate, num_iters) log_joints = -losses return params, log_joints
[docs] def marginal_loglik( data: dict, params: dict, parallel: bool = False, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginal log likelihood of the data""" filter_fn = parallel_hmm_filter if parallel else hmm_filter n_states = params["trans_probs"].shape[0] mll = jax.vmap(filter_fn, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).marginal_loglik.sum() return mll
[docs] def smoothed_states( data: dict, params: dict, parallel: bool = False, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginals of hidden states using forward-backward algorithm.""" smoother_fn = parallel_hmm_smoother if parallel else hmm_smoother n_states = params["trans_probs"].shape[0] return jax.vmap(smoother_fn, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).smoothed_probs
[docs] def filtered_states( data: dict, params: dict, parallel: bool = False, ) -> Float[Array, "n_sequences n_timesteps n_states"]: """Estimate marginals of hidden states using forward-backward algorithm.""" filter_fn = parallel_hmm_filter if parallel else hmm_filter n_states = params["trans_probs"].shape[0] return jax.vmap(filter_fn, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), ).filtered_probs
[docs] def predicted_states( data: dict, params: dict, ) -> Float[Array, "n_sequences n_timesteps"]: """Predict hidden states using Viterbi algorithm.""" n_states = params["trans_probs"].shape[0] return jax.vmap(hmm_posterior_mode, in_axes=(None, None, 0))( jnp.ones(n_states) / n_states, params["trans_probs"], obs_log_likelihoods(data, params), )
[docs] def random_params( seed: Float[Array, "2"], hypparams: dict, ) -> dict: """Generate random model parameters. emission_base ~ Normal(0, emission_base_sigma) emission_biases ~ Normal(0, emission_biases_sigma) trans_probs ~ Dirichlet(trans_beta + trans_kappa * I) Args: seed: random seed hypparams: hyperparameters dictionary Returns: params: parameters dictionary """ n_syllables = hypparams["n_syllables"] n_states = hypparams["n_states"] seeds = jr.split(seed, 3) emission_base = ( jr.normal(seeds[0], shape=(n_syllables, n_syllables - 1)) * hypparams["emission_base_sigma"] ) emission_biases = ( jr.normal(seeds[1], shape=(n_states - 1, n_syllables - 1)) * hypparams["emission_biases_sigma"] ) trans_probs = jax.vmap(jr.dirichlet)( jr.split(seeds[2], n_states), jnp.eye(n_states) * hypparams["trans_kappa"] + hypparams["trans_beta"], ) return { "emission_base": emission_base, "emission_biases": emission_biases, "trans_probs": trans_probs, }
[docs] def simulate( seed: Float[Array, "2"], params: dict, n_timesteps: Int, n_sequences: Int, ) -> Tuple[ Int[Array, "n_sequences n_timesteps"], Int[Array, "n_sequences n_timesteps"] ]: """Simulate data from the model. Args: seed: random seed params: parameters dictionary n_timesteps: number of timesteps to simulate n_sequences: number of sessions to simulate Returns: states: simulated states syllables: simulated syllables """ seeds = jr.split(seed, 3) states = jax.vmap(simulate_markov_chain, in_axes=(0, None, None))( jr.split(seeds[0], n_sequences), params["trans_probs"], n_timesteps, ) syllable_trans_probs = get_syllable_trans_probs( params["emission_base"], params["emission_biases"], )[states] syllables = jax.vmap(simulate_markov_chain, in_axes=(0, 0, None))( jr.split(seeds[1], n_sequences), syllable_trans_probs, n_timesteps ) return states, syllables
[docs] def resample_params( seed: Float[Array, "2"], data: dict, states: Int[Array, "n_sequences n_timesteps"], hypparams: dict, params: dict = None, ) -> Tuple[dict, Float[Array, "gradient_descent_iters"]]: """Resample parameters from their posterior distribution. Emission parameters are resampled using a Laplace approximation; the mode is found using gradient descent. Args: seed: random seed data: data dictionary states: hidden states hypparams: hyperparameters dictionary params: parameters dictionary (optional; used for initializing gradient descent) Returns: params: parameters dictionary losses: losses recorded during gradient descent """ seeds = jr.split(seed, 2) if params is None: init_emission_base = None init_emission_biases = None else: init_emission_base = params["emission_base"] init_emission_biases = params["emission_biases"] emission_params, gd_losses = resample_emission_params( seeds[1], data["syllables"], data["mask"], states, hypparams["n_states"], hypparams["n_syllables"], hypparams["emission_base_sigma"], hypparams["emission_biases_sigma"], hypparams["emission_gd_iters"], hypparams["emission_gd_lr"], init_emission_base, init_emission_biases, ) trans_probs = resample_trans_probs( seeds[2], data["mask"], states, hypparams["n_states"], hypparams["trans_beta"], hypparams["trans_kappa"], ) params = { "emission_base": emission_params[0], "emission_biases": emission_params[1], "trans_probs": trans_probs, } return params, gd_losses
[docs] @partial(jax.jit, static_argnums=(4,5,8,)) def resample_emission_params( seed: Float[Array, "2"], syllables: Int[Array, "n_sequences n_timesteps"], mask: Int[Array, "n_sequences n_timesteps"], states: Int[Array, "n_sequences n_timesteps"], n_states: int, n_syllables: int, emission_base_sigma: Float, emission_biases_sigma: Float, gradient_descent_iters: Int = 100, gradient_descent_lr: Float = 1e-3, init_emission_base: Float[Array, "n_syllables n_syllables-1"] = None, init_emission_biases: Float[Array, "n_states-1 n_syllables-1"] = None, ) -> Tuple[ Tuple[ Float[Array, "n_syllables n_syllables-1"], Float[Array, "n_states-1 n_syllables-1"], ], Float[Array, "gradient_descent_iters"], ]: """Resample emission parameters from their posterior distribution. Args: seed: random seed syllables: syllable observations mask: mask of valid observations states: hidden states n_states: number of hidden states n_syllables: number of syllables emission_base_sigma: emission base standard deviation emission_biases_sigma: emission biases standard deviation gradient_descent_iters: number of gradient descent iterations gradient_descent_lr: gradient descent loss rate init_emission_base: initial emission base parameters (optional) init_emission_biases: initial emission biases parameters (optional) Returns: emission_base: posterior emission base parameters emission_biases: posterior emission biases parameters losses: losses recorded during gradient descent """ sufficient_stats = ( jnp.zeros((n_states, n_syllables, n_syllables)) .at[states[:, 1:], syllables[:, :-1], syllables[:, 1:]] .add(mask[:, 1:]) ) def log_prob_fn(args): emission_base, emission_biases = args syllable_trans_probs = get_syllable_trans_probs(emission_base, emission_biases) emission_base = raise_dim(emission_base, 1) emission_biases = raise_dim(raise_dim(emission_biases, 0), 1) prior_log_prob = ( norm.logpdf(emission_base / emission_base_sigma).sum() + norm.logpdf(emission_biases / emission_biases_sigma).sum() ) syllables_log_prob = (jnp.log(syllable_trans_probs) * sufficient_stats).sum() return prior_log_prob + syllables_log_prob if init_emission_base is None or init_emission_biases is None: init_emission_params = estimate_emission_params(sufficient_stats) else: init_emission_params = (init_emission_base, init_emission_biases) (emission_base, emission_biases), losses = sample_laplace( seed, log_prob_fn, init_emission_params, gradient_descent_iters, gradient_descent_lr, ) return (emission_base, emission_biases), losses
[docs] @partial(jax.jit, static_argnums=(3,)) def resample_trans_probs( seed: Float[Array, "2"], mask: Int[Array, "n_sequences n_timesteps"], states: Int[Array, "n_sequences n_timesteps"], n_states: int, beta: Float, kappa: Float, ) -> Float[Array, "n_states n_states"]: """Resample transition probabilities from their posterior distribution. Args: seed: random seed mask: mask of valid observations states: hidden states n_states: number of hidden states beta: Dirichlet concentration parameter kappa: Dirichlet concentration parameter Returns: trans_probs: posterior transition probabilities """ trans_counts = jax.vmap(count_transitions, in_axes=(0, 0, None))(states, mask, n_states).sum(0) trans_probs = jax.vmap(jr.dirichlet)( jr.split(seed, n_states), trans_counts + beta + jnp.eye(n_states) * kappa ) return trans_probs