Efficient HHMM

Functions:

estimate_emission_params(sufficient_stats)

Estimate emission parameters from transition counts.

get_syllable_trans_probs(emission_base, ...)

Compute transition probabilities between syllables.

obs_log_likelihoods(data, params)

Compute log likelihoods of observations for each hidden state.

log_params_prob(params, hypparams)

Compute the log probability of the parameters based on their priors.

log_joint_prob(data, params, hypparams)

Compute the log joint probability of the data and parameters.

resample_states(seed, data, params[, parallel])

Resample hidden states from their posterior distribution.

fit_gibbs(data, hypparams, init_params[, ...])

Fit a model using Gibbs sampling.

initialize_params(data, hypparams[, states, ...])

Initialize parameters by sampling from their prior distribution or using provided states.

fit_gradient_descent(data, hypparams, ...[, ...])

Fit a model using gradient descent.

marginal_loglik(data, params[, parallel])

Estimate marginal log likelihood of the data.

smoothed_states(data, params[, parallel])

Estimate marginals of hidden states using forward-backward algorithm.

filtered_states(data, params[, parallel])

Estimate filtered marginals of hidden states using the forward algorithm.

predicted_states(data, params)

Predict hidden states using Viterbi algorithm.

random_params(seed, hypparams)

Generate random model parameters.

simulate(seed, params, n_timesteps, n_sequences)

Simulate data from the model.

resample_params(seed, data, states, hypparams)

Resample parameters from their posterior distribution.

resample_emission_params(seed, syllables, ...)

Resample emission parameters from their posterior distribution.

resample_trans_probs(seed, mask, states, ...)

Resample transition probabilities from their posterior distribution.

state_moseq.hhmm_efficient.estimate_emission_params(sufficient_stats)[source]

Estimate emission parameters from transition counts.

Parameters:

sufficient_stats (Float[Array, 'n_states n_syllables n_syllables'])

Return type:

Tuple[Float[Array, 'n_syllables n_syllables-1'], Float[Array, 'n_states-1 n_syllables-1']]

state_moseq.hhmm_efficient.get_syllable_trans_probs(emission_base, emission_biases)[source]

Compute transition probabilities between syllables.

Parameters:
  • emission_base (Float[Array, 'n_syllables n_syllables-1'])

  • emission_biases (Float[Array, 'n_states-1 n_syllables-1'])

Return type:

Float[Array, 'n_states n_syllables n_syllables']
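The input shapes suggest a softmax parameterization in which one logit per row is fixed at zero and one state acts as a reference with zero bias. The NumPy sketch below shows one way the two arrays could combine into per-state transition matrices. The padding scheme, the choice of reference state, and the name syllable_trans_probs_sketch are all assumptions for illustration; the library's actual parameterization may differ.

```python
import numpy as np

def syllable_trans_probs_sketch(emission_base, emission_biases):
    """Hypothetical reconstruction; not the library's implementation."""
    n_syllables = emission_base.shape[0]
    n_states = emission_biases.shape[0] + 1
    # Assumption: the last logit in each row is pinned to zero.
    base = np.concatenate([emission_base, np.zeros((n_syllables, 1))], axis=1)
    # Assumption: state 0 is the reference state and carries zero bias.
    biases = np.concatenate([np.zeros((1, n_syllables - 1)), emission_biases], axis=0)
    biases = np.concatenate([biases, np.zeros((n_states, 1))], axis=1)
    # logits[k, i, j] = base[i, j] + biases[k, j]
    logits = base[None, :, :] + biases[:, None, :]
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_states, n_syllables = 3, 5
probs = syllable_trans_probs_sketch(
    rng.normal(size=(n_syllables, n_syllables - 1)),
    rng.normal(size=(n_states - 1, n_syllables - 1)),
)
```

Each probs[k, i] is a distribution over the next syllable given hidden state k and current syllable i.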

state_moseq.hhmm_efficient.obs_log_likelihoods(data, params)[source]

Compute log likelihoods of observations for each hidden state.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
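As a rough illustration of the output shape, the sketch below scores each observed syllable transition under every hidden state. It assumes the observations are integer syllable labels, that the likelihood at time t is the state-specific probability of moving from syllable t-1 to syllable t, and that the first timestep contributes zero. None of this is confirmed by the text above, and obs_log_likelihoods_sketch is a hypothetical name.

```python
import numpy as np

def obs_log_likelihoods_sketch(syllables, syllable_trans_probs):
    """syllables: (n_sequences, n_timesteps) ints;
    syllable_trans_probs: (n_states, n_syllables, n_syllables)."""
    n_sequences, n_timesteps = syllables.shape
    n_states = syllable_trans_probs.shape[0]
    log_probs = np.log(syllable_trans_probs)
    # Assumption: the first timestep has no preceding syllable and scores 0.
    ll = np.zeros((n_sequences, n_timesteps, n_states))
    prev, curr = syllables[:, :-1], syllables[:, 1:]
    # log_probs[:, prev, curr] has shape (n_states, n_sequences, n_timesteps - 1).
    ll[:, 1:, :] = np.moveaxis(log_probs[:, prev, curr], 0, -1)
    return ll

syllables = np.array([[0, 1, 2, 1]])
uniform = np.full((2, 3, 3), 1.0 / 3.0)
ll = obs_log_likelihoods_sketch(syllables, uniform)
```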

state_moseq.hhmm_efficient.log_params_prob(params, hypparams)[source]

Compute the log probability of the parameters based on their priors.

Parameters:
  • params (dict)

  • hypparams (dict)

Return type:

Float

state_moseq.hhmm_efficient.log_joint_prob(data, params, hypparams)[source]

Compute the log joint probability of the data and parameters.

Parameters:
  • data (dict)

  • params (dict)

  • hypparams (dict)

Return type:

Float

state_moseq.hhmm_efficient.resample_states(seed, data, params, parallel=False)[source]

Resample hidden states from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • data (dict) – data dictionary

  • params (dict) – parameters dictionary

  • parallel (bool) – whether to use parallel message passing

Returns:
  • states – resampled hidden states

  • marginal_loglik – marginal log likelihood of the data
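A standard way to draw hidden states from their posterior while also obtaining the marginal log likelihood is forward filtering followed by backward sampling. The sketch below does this for a single sequence in plain NumPy. It is not the library's code: the function name, the explicit init_probs argument, and the use of a NumPy generator in place of a JAX key are assumptions.

```python
import numpy as np

def sample_states_ffbs(rng, likes, trans_probs, init_probs):
    """likes: (T, K) observation likelihoods (not logs) for one sequence."""
    T, K = likes.shape
    alpha = np.zeros((T, K))
    loglik = 0.0
    pred = init_probs
    # Forward pass: normalized filtered marginals and running log likelihood.
    for t in range(T):
        a = pred * likes[t]
        norm = a.sum()
        alpha[t] = a / norm
        loglik += np.log(norm)
        pred = alpha[t] @ trans_probs
    # Backward pass: sample the last state, then each earlier state
    # conditioned on the state sampled after it.
    states = np.zeros(T, dtype=int)
    states[-1] = rng.choice(K, p=alpha[-1])
    for t in range(T - 2, -1, -1):
        w = alpha[t] * trans_probs[:, states[t + 1]]
        states[t] = rng.choice(K, p=w / w.sum())
    return states, loglik

rng = np.random.default_rng(0)
likes = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
trans_probs = np.full((2, 2), 0.5)
states, loglik = sample_states_ffbs(rng, likes, trans_probs, np.array([0.5, 0.5]))
```

With these one-hot likelihoods the posterior is concentrated on a single path, so the sample is deterministic.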

state_moseq.hhmm_efficient.fit_gibbs(data, hypparams, init_params, init_states=None, seed=Array([0, 0], dtype=uint32), num_iters=100, parallel=False)[source]

Fit a model using Gibbs sampling.

Parameters:
  • data (dict) – data dictionary

  • hypparams (dict) – hyperparameters dictionary

  • init_params (dict) – initial parameters dictionary

  • init_states (Int[Array, 'n_sequences n_timesteps'] | None) – initial hidden states (optional)

  • seed (Float[Array, '2']) – random seed

  • num_iters (Int) – number of iterations

  • parallel (bool) – whether to use parallel message passing

Returns:
  • params – fitted parameters dictionary

  • log_joints – log joint probability of the data and parameters recorded at each iteration

state_moseq.hhmm_efficient.initialize_params(data, hypparams, states=None, seed=Array([0, 0], dtype=uint32))[source]

Initialize parameters by sampling from their prior distribution or using provided states.

Parameters:
  • data (dict) – data dictionary

  • hypparams (dict) – hyperparameters dictionary

  • states (Int[Array, 'n_sequences n_timesteps'] | None) – states used for initializing the parameters (optional)

  • seed (Float[Array, '2']) – random seed

Return type:

dict

state_moseq.hhmm_efficient.fit_gradient_descent(data, hypparams, init_params, num_iters=100, learning_rate=0.001)[source]

Fit a model using gradient descent.

Parameters:
  • data (dict) – data dictionary

  • hypparams (dict) – hyperparameters dictionary

  • init_params (dict) – initial parameters dictionary

  • num_iters (Int) – number of iterations

  • learning_rate (Float) – learning rate for gradient descent

Returns:
  • params – fitted parameters dictionary

  • log_joints – log joint probability of the data and parameters recorded at each iteration

state_moseq.hhmm_efficient.marginal_loglik(data, params, parallel=False)[source]

Estimate marginal log likelihood of the data.

Parameters:
  • data (dict)

  • params (dict)

  • parallel (bool)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
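For reference, the marginal log likelihood of an HMM is usually computed with the forward algorithm in log space. The sketch below handles a single sequence and returns one scalar. It is a generic illustration rather than the library's routine; forward_loglik, the logsumexp helper, and the init_probs argument are assumptions.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))

def forward_loglik(log_likes, trans_probs, init_probs):
    """log p(observations) for one sequence; log_likes has shape (T, K)."""
    log_trans = np.log(trans_probs)
    log_alpha = np.log(init_probs) + log_likes[0]
    for t in range(1, log_likes.shape[0]):
        # Sum over the previous state, then add the current observation term.
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_likes[t]
    return logsumexp(log_alpha, axis=0)

trans_probs = np.array([[0.9, 0.1], [0.2, 0.8]])
init_probs = np.array([0.6, 0.4])
# With uninformative observations (log likelihood 0 everywhere) the result is 0.
flat = forward_loglik(np.zeros((4, 2)), trans_probs, init_probs)
```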

state_moseq.hhmm_efficient.smoothed_states(data, params, parallel=False)[source]

Estimate marginals of hidden states using forward-backward algorithm.

Parameters:
  • data (dict)

  • params (dict)

  • parallel (bool)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
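The forward-backward algorithm can be sketched for a single sequence as follows. The forward pass alone yields filtered marginals p(z_t | x_1..t); combining it with the backward pass yields smoothed marginals p(z_t | x_1..T). This is a textbook NumPy version under assumed inputs (forward_backward and init_probs are hypothetical names), not the library's implementation.

```python
import numpy as np

def forward_backward(likes, trans_probs, init_probs):
    """likes: (T, K) observation likelihoods (not logs) for one sequence."""
    T, K = likes.shape
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    a = init_probs * likes[0]
    alpha[0] = a / a.sum()
    for t in range(1, T):
        a = (alpha[t - 1] @ trans_probs) * likes[t]
        alpha[t] = a / a.sum()  # filtered marginal at time t
    for t in range(T - 2, -1, -1):
        b = trans_probs @ (likes[t + 1] * beta[t + 1])
        beta[t] = b / b.sum()  # rescaled backward message
    smoothed = alpha * beta
    smoothed /= smoothed.sum(axis=1, keepdims=True)
    return alpha, smoothed

likes = np.array([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]])
trans_probs = np.array([[0.9, 0.1], [0.1, 0.9]])
init_probs = np.array([0.5, 0.5])
filtered, smoothed = forward_backward(likes, trans_probs, init_probs)
```

The two outputs agree at the final timestep, where no future observations remain.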

state_moseq.hhmm_efficient.filtered_states(data, params, parallel=False)[source]

Estimate filtered marginals of hidden states using the forward algorithm.

Parameters:
  • data (dict)

  • params (dict)

  • parallel (bool)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']

state_moseq.hhmm_efficient.predicted_states(data, params)[source]

Predict hidden states using Viterbi algorithm.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float[Array, 'n_sequences n_timesteps']
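The Viterbi algorithm finds the single most probable state path by dynamic programming. A minimal NumPy sketch for one sequence is given below; the function name and the init_probs argument are assumptions, and the library may organize the computation differently.

```python
import numpy as np

def viterbi(log_likes, trans_probs, init_probs):
    """Most probable state path for one sequence; log_likes has shape (T, K)."""
    T, K = log_likes.shape
    log_trans = np.log(trans_probs)
    score = np.log(init_probs) + log_likes[0]
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        # cand[i, j]: best score of a path ending in state i that moves to j.
        cand = score[:, None] + log_trans
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_likes[t]
    # Trace the best path backwards from the best final state.
    path = np.zeros(T, dtype=int)
    path[-1] = score.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path

log_likes = np.log(np.array([[0.9, 0.1]] * 3 + [[0.1, 0.9]] * 3))
trans_probs = np.array([[0.9, 0.1], [0.1, 0.9]])
path = viterbi(log_likes, trans_probs, np.array([0.5, 0.5]))
```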

state_moseq.hhmm_efficient.random_params(seed, hypparams)[source]

Generate random model parameters.

emission_base ~ Normal(0, emission_base_sigma)

emission_biases ~ Normal(0, emission_biases_sigma)

trans_probs ~ Dirichlet(trans_beta + trans_kappa * I)

Parameters:
  • seed (Float[Array, '2']) – random seed

  • hypparams (dict) – hyperparameters dictionary

Returns:

params – parameters dictionary
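The prior stated above can be sampled directly. The sketch below uses NumPy instead of JAX and passes hyperparameters as explicit arguments. The array shapes are taken from get_syllable_trans_probs, while the argument names, the returned dictionary keys, and the row-wise reading of the Dirichlet are assumptions.

```python
import numpy as np

def random_params_sketch(rng, n_states, n_syllables, emission_base_sigma,
                         emission_biases_sigma, trans_beta, trans_kappa):
    """Draw parameters from the stated priors (hypothetical helper)."""
    emission_base = rng.normal(
        0.0, emission_base_sigma, size=(n_syllables, n_syllables - 1))
    emission_biases = rng.normal(
        0.0, emission_biases_sigma, size=(n_states - 1, n_syllables - 1))
    # Row k of the concentration matrix is trans_beta everywhere,
    # plus trans_kappa on the self-transition entry.
    concentration = trans_beta + trans_kappa * np.eye(n_states)
    trans_probs = np.stack([rng.dirichlet(row) for row in concentration])
    return {
        "emission_base": emission_base,
        "emission_biases": emission_biases,
        "trans_probs": trans_probs,
    }

rng = np.random.default_rng(0)
params = random_params_sketch(rng, n_states=3, n_syllables=5,
                              emission_base_sigma=1.0, emission_biases_sigma=0.5,
                              trans_beta=1.0, trans_kappa=10.0)
```

A larger trans_kappa concentrates prior mass on the diagonal, which favors longer dwell times in each state.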

state_moseq.hhmm_efficient.simulate(seed, params, n_timesteps, n_sequences)[source]

Simulate data from the model.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • params (dict) – parameters dictionary

  • n_timesteps (Int) – number of timesteps to simulate

  • n_sequences (Int) – number of sequences to simulate

Returns:
  • states – simulated states

  • syllables – simulated syllables
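The generative process can be sketched as two coupled Markov chains: hidden states evolve under trans_probs, and each syllable is drawn from the transition matrix of the current hidden state. The version below is an illustration under assumptions (uniform initial state and syllable, syllable transitions conditioned on the state at the current timestep, the name simulate_sketch), not the library's simulator.

```python
import numpy as np

def simulate_sketch(rng, trans_probs, syllable_trans_probs, n_timesteps, n_sequences):
    """trans_probs: (K, K); syllable_trans_probs: (K, S, S)."""
    n_states = trans_probs.shape[0]
    n_syllables = syllable_trans_probs.shape[1]
    states = np.zeros((n_sequences, n_timesteps), dtype=int)
    syllables = np.zeros((n_sequences, n_timesteps), dtype=int)
    for n in range(n_sequences):
        # Assumption: uniform initial distributions for state and syllable.
        states[n, 0] = rng.integers(n_states)
        syllables[n, 0] = rng.integers(n_syllables)
        for t in range(1, n_timesteps):
            states[n, t] = rng.choice(n_states, p=trans_probs[states[n, t - 1]])
            p = syllable_trans_probs[states[n, t], syllables[n, t - 1]]
            syllables[n, t] = rng.choice(n_syllables, p=p)
    return states, syllables

rng = np.random.default_rng(0)
trans_probs = np.eye(2)  # states never switch in this toy example
syllable_trans_probs = np.full((2, 3, 3), 1.0 / 3.0)
states, syllables = simulate_sketch(rng, trans_probs, syllable_trans_probs, 6, 4)
```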

state_moseq.hhmm_efficient.resample_params(seed, data, states, hypparams, params=None)[source]

Resample parameters from their posterior distribution. Emission parameters are resampled using a Laplace approximation; the mode is found using gradient descent.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • data (dict) – data dictionary

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • hypparams (dict) – hyperparameters dictionary

  • params (dict | None) – parameters dictionary (optional; used for initializing gradient descent)

Returns:
  • params – parameters dictionary

  • losses – losses recorded during gradient descent
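The Laplace approximation mentioned above can be illustrated in isolation: find the posterior mode by gradient descent on the negative log posterior, then sample from a Gaussian centered at the mode whose covariance is the inverse Hessian at that point. The sketch below takes the gradient and Hessian as plain callables and uses a quadratic toy objective. The function name and arguments are hypothetical, and how the library obtains gradients and curvature is not shown here.

```python
import numpy as np

def laplace_sample(rng, grad, hess, x0, lr=0.1, iters=500):
    """Sample from a Laplace approximation of exp(-loss(x)).

    grad and hess are the gradient and Hessian of the loss
    (the negative log posterior)."""
    x = np.array(x0, dtype=float)
    for _ in range(iters):
        x = x - lr * grad(x)  # plain gradient descent toward the mode
    cov = np.linalg.inv(hess(x))  # inverse curvature at the mode
    return rng.multivariate_normal(x, cov), x

# Toy objective: loss(x) = 0.5 * (x - mu)^T A (x - mu), whose mode is mu.
mu = np.array([1.0, -2.0])
A = np.diag([1.0, 2.0])
rng = np.random.default_rng(0)
sample, mode = laplace_sample(
    rng, grad=lambda x: A @ (x - mu), hess=lambda x: A, x0=np.zeros(2))
```

For this quadratic loss the approximation is exact, so the sample is a draw from Normal(mu, inverse of A).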

state_moseq.hhmm_efficient.resample_emission_params(seed, syllables, mask, states, n_states, n_syllables, emission_base_sigma, emission_biases_sigma, gradient_descent_iters=100, gradient_descent_lr=0.001, init_emission_base=None, init_emission_biases=None)[source]

Resample emission parameters from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • syllables (Int[Array, 'n_sequences n_timesteps']) – syllable observations

  • mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • n_states (int) – number of hidden states

  • n_syllables (int) – number of syllables

  • emission_base_sigma (Float) – emission base standard deviation

  • emission_biases_sigma (Float) – emission biases standard deviation

  • gradient_descent_iters (Int) – number of gradient descent iterations

  • gradient_descent_lr (Float) – gradient descent learning rate

  • init_emission_base (Float[Array, 'n_syllables n_syllables-1']) – initial emission base parameters (optional)

  • init_emission_biases (Float[Array, 'n_states-1 n_syllables-1']) – initial emission biases parameters (optional)

Returns:
  • emission_base – posterior emission base parameters

  • emission_biases – posterior emission biases parameters

  • losses – losses recorded during gradient descent

state_moseq.hhmm_efficient.resample_trans_probs(seed, mask, states, n_states, beta, kappa)[source]

Resample transition probabilities from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • n_states (int) – number of hidden states

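  • beta (Float) – Dirichlet concentration applied to every transition (cf. trans_beta in random_params)

  • kappa (Float) – additional Dirichlet concentration on self-transitions (cf. trans_kappa in random_params)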

Returns:

trans_probs – posterior transition probabilities
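Because the Dirichlet prior is conjugate to the categorical transition model, the posterior for each row of the transition matrix is again a Dirichlet whose concentration is the prior concentration plus the observed transition counts. The sketch below assumes the prior Dirichlet(beta + kappa * I) given under random_params and counts only transitions whose endpoints are both unmasked. The helper names and the NumPy generator in place of a JAX key are assumptions.

```python
import numpy as np

def count_transitions(mask, states, n_states):
    """Count state transitions between consecutive valid timesteps."""
    prev, curr = states[:, :-1], states[:, 1:]
    valid = (mask[:, :-1] * mask[:, 1:]).astype(bool)
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (prev[valid], curr[valid]), 1)
    return counts

def resample_trans_probs_sketch(rng, mask, states, n_states, beta, kappa):
    """Draw each row from its Dirichlet posterior (hypothetical helper)."""
    counts = count_transitions(mask, states, n_states)
    concentration = counts + beta + kappa * np.eye(n_states)
    return np.stack([rng.dirichlet(row) for row in concentration])

rng = np.random.default_rng(0)
states = np.array([[0, 0, 1, 1, 0]])
mask = np.array([[1, 1, 1, 0, 0]])
counts = count_transitions(mask, states, n_states=2)
trans_probs = resample_trans_probs_sketch(rng, mask, states, 2, beta=1.0, kappa=5.0)
```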