Efficient HHMM
Functions:
| Function | Description |
|---|---|
| estimate_emission_params | Estimate emission parameters from transition counts. |
| get_syllable_trans_probs | Compute transition probabilities between syllables. |
| obs_log_likelihoods | Compute log likelihoods of observations for each hidden state. |
| log_params_prob | Compute the log probability of the parameters under their priors. |
| log_joint_prob | Compute the log joint probability of the data and parameters. |
| resample_states | Resample hidden states from their posterior distribution. |
| fit_gibbs | Fit a model using Gibbs sampling. |
| initialize_params | Initialize parameters by sampling from their prior distribution or using provided states. |
| fit_gradient_descent | Fit a model using gradient descent. |
| marginal_loglik | Estimate the marginal log likelihood of the data. |
| smoothed_states | Estimate smoothed marginals of hidden states using the forward-backward algorithm. |
| filtered_states | Estimate filtered marginals of hidden states using the forward algorithm. |
| predicted_states | Predict hidden states using the Viterbi algorithm. |
| random_params | Generate random model parameters. |
| simulate | Simulate data from the model. |
| resample_params | Resample parameters from their posterior distribution. |
| resample_emission_params | Resample emission parameters from their posterior distribution. |
| resample_trans_probs | Resample transition probabilities from their posterior distribution. |
- state_moseq.hhmm_efficient.estimate_emission_params(sufficient_stats)
Estimate emission parameters from transition counts.
- Parameters:
sufficient_stats (Float[Array, 'n_states n_syllables n_syllables'])
- Return type:
Tuple[Float[Array, 'n_syllables n_syllables-1'], Float[Array, 'n_states-1 n_syllables-1']]
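The count tensor passed as sufficient_stats can be assembled from syllable and state sequences. The NumPy sketch below is illustrative only: the helper name transition_counts is hypothetical, and labelling each transition a -> b by the hidden state at the destination timestep, and skipping pairs that touch a masked step, are assumptions rather than documented behavior.

```python
import numpy as np

def transition_counts(syllables, mask, states, n_states, n_syllables):
    """Hypothetical helper: count syllable transitions a -> b, split by hidden state.

    syllables, mask, states: integer arrays of shape (n_sequences, n_timesteps).
    Returns counts of shape (n_states, n_syllables, n_syllables).
    """
    counts = np.zeros((n_states, n_syllables, n_syllables))
    prev = syllables[:, :-1]  # source syllable a at time t-1
    curr = syllables[:, 1:]   # destination syllable b at time t
    z = states[:, 1:]         # assumption: the state at time t labels the a -> b step
    valid = (mask[:, :-1] > 0) & (mask[:, 1:] > 0)  # both ends must be observed
    # Unbuffered scatter-add so repeated (z, a, b) triples are all counted
    np.add.at(counts, (z[valid], prev[valid], curr[valid]), 1)
    return counts
```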
- state_moseq.hhmm_efficient.get_syllable_trans_probs(emission_base, emission_biases)
Compute transition probabilities between syllables.
- Parameters:
emission_base (Float[Array, 'n_syllables n_syllables-1'])
emission_biases (Float[Array, 'n_states-1 n_syllables-1'])
- Return type:
Float[Array, 'n_states n_syllables n_syllables']
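One reading of the shapes above is a softmax with a reference category: emission_base holds logits for n_syllables-1 target syllables per source syllable, and emission_biases shifts those logits for every state except a reference state. The sketch below assumes exactly that (last syllable fixed at logit 0, state 0 fixed at zero bias); the function name is hypothetical and the library's actual parameterization may differ.

```python
import numpy as np

def syllable_trans_probs(emission_base, emission_biases):
    """Hypothetical sketch: map logits to per-state syllable transition matrices.

    emission_base: (n_syllables, n_syllables-1); emission_biases: (n_states-1, n_syllables-1).
    Assumes the last syllable has logit 0 and state 0 has zero bias.
    """
    n_states = emission_biases.shape[0] + 1
    n_syllables = emission_base.shape[0]
    # Prepend a zero bias row for the reference state: (n_states, n_syllables-1)
    biases = np.concatenate([np.zeros((1, emission_biases.shape[1])), emission_biases], axis=0)
    # Broadcast (1, S, S-1) + (K, 1, S-1) -> (K, S, S-1)
    logits = emission_base[None, :, :] + biases[:, None, :]
    # Append the reference logit 0 for the last syllable: (K, S, S)
    logits = np.concatenate([logits, np.zeros((n_states, n_syllables, 1))], axis=-1)
    # Row-wise softmax, shifted by the row max for numerical stability
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)
```

With all-zero inputs every row is uniform, which is a quick sanity check on the reference-category construction.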
- state_moseq.hhmm_efficient.obs_log_likelihoods(data, params)
Compute log likelihoods of observations for each hidden state.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
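If each hidden state selects a syllable transition matrix, the observation log likelihood for state k at time t is the log probability of the transition x_{t-1} -> x_t under that state's matrix. A NumPy sketch with a hypothetical name; treating t = 0 (no previous syllable) and masked steps as contributing 0 is an assumption.

```python
import numpy as np

def obs_log_likelihoods_sketch(syllables, mask, syllable_trans):
    """Hypothetical sketch: log p(x_t | x_{t-1}, z_t = k) for every sequence, step and state.

    syllables, mask: (n_sequences, n_timesteps) integer arrays.
    syllable_trans: (n_states, n_syllables, n_syllables), rows sum to 1.
    Returns an array of shape (n_sequences, n_timesteps, n_states).
    """
    n_seq, n_t = syllables.shape
    n_states = syllable_trans.shape[0]
    lls = np.zeros((n_seq, n_t, n_states))
    prev, curr = syllables[:, :-1], syllables[:, 1:]
    # Indexing gives (n_states, n_seq, n_t-1); move the state axis last
    lls[:, 1:, :] = np.log(syllable_trans[:, prev, curr]).transpose(1, 2, 0)
    # Assumption: t = 0 and transitions touching a masked step contribute 0
    valid = np.zeros((n_seq, n_t), dtype=bool)
    valid[:, 1:] = (mask[:, :-1] > 0) & (mask[:, 1:] > 0)
    return np.where(valid[..., None], lls, 0.0)
```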
- state_moseq.hhmm_efficient.log_params_prob(params, hypparams)
Compute the log probability of the parameters based on their priors.
- Parameters:
params (dict)
hypparams (dict)
- Return type:
Float
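The priors listed under random_params suggest what this quantity sums. The sketch below adds independent Normal log densities for the two emission arrays and one Dirichlet log density per row of trans_probs, reading trans_beta + trans_kappa * I row by row. The function name and the dictionary key spellings are assumptions.

```python
import math
import numpy as np

def normal_logpdf(x, sigma):
    """Sum of independent Normal(0, sigma^2) log densities over all entries of x."""
    x = np.asarray(x, dtype=float)
    return float(np.sum(-0.5 * (x / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)))

def dirichlet_logpdf(p, alpha):
    """Log density of Dirichlet(alpha) at the probability vector p."""
    log_norm = math.lgamma(float(np.sum(alpha))) - sum(math.lgamma(a) for a in alpha)
    return log_norm + float(np.sum((np.asarray(alpha) - 1.0) * np.log(p)))

def log_params_prob_sketch(params, hypparams):
    """Hypothetical sketch of log p(params); key names are assumptions."""
    trans_probs = params["trans_probs"]
    n_states = trans_probs.shape[0]
    # Row k: concentration beta on every entry plus kappa on the self-transition
    alpha = hypparams["trans_beta"] + hypparams["trans_kappa"] * np.eye(n_states)
    lp = normal_logpdf(params["emission_base"], hypparams["emission_base_sigma"])
    lp += normal_logpdf(params["emission_biases"], hypparams["emission_biases_sigma"])
    lp += sum(dirichlet_logpdf(trans_probs[k], alpha[k]) for k in range(n_states))
    return lp
```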
- state_moseq.hhmm_efficient.log_joint_prob(data, params, hypparams)
Compute the log joint probability of the data and parameters.
- Parameters:
data (dict)
params (dict)
hypparams (dict)
- Return type:
Float
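Because log_joint_prob takes no states argument, one plausible reading (an assumption, not stated in the docs) is that the hidden states are summed out, so the quantity equals log_params_prob plus the marginal log likelihood summed over the N sequences, with θ denoting (emission_base, emission_biases, trans_probs):

```latex
\log p\big(x^{(1:N)}, \theta\big) = \log p(\theta) + \sum_{n=1}^{N} \log \sum_{z^{(n)}} p\big(x^{(n)}, z^{(n)} \mid \theta\big)
```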
- state_moseq.hhmm_efficient.resample_states(seed, data, params, parallel=False)
Resample hidden states from their posterior distribution.
- Parameters:
seed (Float[Array, '2']) – random seed
data (dict) – data dictionary
params (dict) – parameters dictionary
parallel (bool) – whether to use parallel message passing
- Returns:
states – resampled hidden states
marginal_loglik – marginal log likelihood of the data
- Return type:
tuple (states, marginal_loglik)
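Posterior sampling of hidden states in an HMM is commonly done by forward filtering, backward sampling, which also yields the marginal log likelihood as a by-product, matching the two outputs listed above. A single-sequence NumPy sketch, assuming a uniform initial state distribution and precomputed per-state log likelihoods; the name ffbs is hypothetical and the parallel message-passing option is not modelled.

```python
import numpy as np

def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along an axis."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def ffbs(rng, log_likes, trans_probs):
    """Forward filtering, backward sampling for one sequence.

    log_likes: (n_timesteps, n_states); trans_probs: (n_states, n_states).
    Assumes a uniform initial state distribution.
    Returns (sampled states, marginal log likelihood).
    """
    T, K = log_likes.shape
    log_A = np.log(trans_probs)
    # Forward pass: log_alpha[t, k] = log p(x_0..x_t, z_t = k)
    log_alpha = np.zeros((T, K))
    log_alpha[0] = log_likes[0] - np.log(K)
    for t in range(1, T):
        log_alpha[t] = log_likes[t] + logsumexp(log_alpha[t - 1][:, None] + log_A, axis=0)
    marginal_loglik = logsumexp(log_alpha[-1])
    # Backward pass: sample the last state from its filtered distribution, then step back
    states = np.zeros(T, dtype=int)
    p = np.exp(log_alpha[-1] - marginal_loglik)
    states[-1] = rng.choice(K, p=p / p.sum())
    for t in range(T - 2, -1, -1):
        # p(z_t | z_{t+1}, x_0..x_t) is proportional to alpha_t(z_t) * A[z_t, z_{t+1}]
        logp = log_alpha[t] + log_A[:, states[t + 1]]
        p = np.exp(logp - logsumexp(logp))
        states[t] = rng.choice(K, p=p / p.sum())
    return states, float(marginal_loglik)
```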
- state_moseq.hhmm_efficient.fit_gibbs(data, hypparams, init_params, init_states=None, seed=Array([0, 0], dtype=uint32), num_iters=100, parallel=False)
Fit a model using Gibbs sampling.
- Parameters:
data (dict) – data dictionary
hypparams (dict) – hyperparameters dictionary
init_params (dict) – initial parameters dictionary
init_states (Int[Array, 'n_sequences n_timesteps'] | None) – initial hidden states (optional)
seed (Float[Array, '2']) – random seed
num_iters (Int) – number of iterations
parallel (bool) – whether to use parallel message passing
- Returns:
params – fitted parameters dictionary
log_joints – log joint probability of the data and parameters recorded at each iteration
- Return type:
tuple (params, log_joints)
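The Gibbs loop alternates the two conditional updates documented on this page (resample_states and resample_params) and records a log joint at each iteration. The skeleton below passes those updates in as callables so it runs on its own; the name gibbs_loop, the callable signatures, and the choice to update states before parameters are assumptions.

```python
import numpy as np

def gibbs_loop(rng, init_params, resample_states_fn, resample_params_fn, log_joint_fn, num_iters=100):
    """Hypothetical blocked Gibbs skeleton.

    resample_states_fn(rng, params) -> states          draws z ~ p(z | x, params)
    resample_params_fn(rng, states, params) -> params  draws params ~ p(params | x, z)
    log_joint_fn(params) -> float                      quantity tracked per iteration
    """
    params = init_params
    log_joints = []
    for _ in range(num_iters):
        states = resample_states_fn(rng, params)
        params = resample_params_fn(rng, states, params)
        log_joints.append(log_joint_fn(params))
    return params, np.array(log_joints)
```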
- state_moseq.hhmm_efficient.initialize_params(data, hypparams, states=None, seed=Array([0, 0], dtype=uint32))
Initialize parameters by sampling from their prior distribution or using provided states.
- Parameters:
data (dict) – data dictionary
hypparams (dict) – hyperparameters dictionary
states (Int[Array, 'n_sequences n_timesteps'] | None) – states used for initializing the parameters (optional)
seed (Float[Array, '2']) – random seed
- Return type:
dict
- state_moseq.hhmm_efficient.fit_gradient_descent(data, hypparams, init_params, num_iters=100, learning_rate=0.001)
Fit a model using gradient descent.
- Parameters:
data (dict) – data dictionary
hypparams (dict) – hyperparameters dictionary
init_params (dict) – initial parameters dictionary
num_iters (Int) – number of iterations
learning_rate (Float) – learning rate for gradient descent
- Returns:
params – fitted parameters dictionary
log_joints – log joint probability of the data and parameters recorded at each iteration
- Return type:
tuple (params, log_joints)
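A minimal gradient descent loop with the same num_iters and learning_rate knobs, recording the objective at each step. Here the gradient is supplied by hand; in a JAX implementation it would more typically come from jax.grad. To maximize a log joint, minimize its negation. The function name is hypothetical.

```python
import numpy as np

def gradient_descent(grad_fn, objective_fn, init_x, num_iters=100, learning_rate=0.001):
    """Plain gradient descent that records the objective after every update.

    To maximize a log joint probability, pass its negation as objective_fn
    and the gradient of that negation as grad_fn.
    """
    x = np.asarray(init_x, dtype=float).copy()
    history = []
    for _ in range(num_iters):
        x = x - learning_rate * grad_fn(x)
        history.append(objective_fn(x))
    return x, np.array(history)
```

For example, minimizing 0.5 * ||x - c||^2 with learning_rate=0.1 contracts the error by a factor of 0.9 per step.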
- state_moseq.hhmm_efficient.marginal_loglik(data, params, parallel=False)
Estimate the marginal log likelihood of the data.
- Parameters:
data (dict)
params (dict)
parallel (bool)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
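For a single sequence, the marginal log likelihood can be computed with the forward algorithm by accumulating the log normalizer of each filtering step. A NumPy sketch with a hypothetical name, assuming an explicit initial distribution init_probs and precomputed per-state log likelihoods (for example from obs_log_likelihoods); how the documented function aggregates over sequences is not stated here.

```python
import numpy as np

def forward_loglik(log_likes, trans_probs, init_probs):
    """log p(x_0..x_{T-1}) for one sequence via the normalized forward algorithm.

    log_likes: (n_timesteps, n_states); trans_probs: (n_states, n_states);
    init_probs: (n_states,). Normalizing at each step avoids underflow.
    """
    log_norm = 0.0
    alpha = init_probs
    for t in range(log_likes.shape[0]):
        if t > 0:
            alpha = alpha @ trans_probs                  # predict: p(z_t | x_0..x_{t-1})
        shift = log_likes[t].max()
        alpha = alpha * np.exp(log_likes[t] - shift)     # weight by the evidence at t
        c = alpha.sum()
        log_norm += np.log(c) + shift                    # log p(x_t | x_0..x_{t-1})
        alpha = alpha / c                                # filtered: p(z_t | x_0..x_t)
    return float(log_norm)
```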
- state_moseq.hhmm_efficient.smoothed_states(data, params, parallel=False)
Estimate smoothed marginals of hidden states using the forward-backward algorithm.
- Parameters:
data (dict)
params (dict)
parallel (bool)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
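Smoothed marginals p(z_t | x_{0:T-1}) combine a forward (filtering) pass with a backward pass. The sketch below normalizes the messages at every step to avoid underflow; it covers one sequence and takes an explicit initial distribution, both assumptions about how the inputs are arranged. The function name is hypothetical.

```python
import numpy as np

def smoothed_marginals(log_likes, trans_probs, init_probs):
    """p(z_t = k | all observations) for one sequence via normalized forward-backward."""
    T, K = log_likes.shape
    # Rescale likelihoods per timestep; these constant factors cancel after normalization
    lik = np.exp(log_likes - log_likes.max(axis=1, keepdims=True))
    alpha = np.zeros((T, K))
    a = init_probs * lik[0]
    alpha[0] = a / a.sum()
    for t in range(1, T):
        a = (alpha[t - 1] @ trans_probs) * lik[t]
        alpha[t] = a / a.sum()
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        # beta_t(j) is proportional to sum_k A[j, k] * p(x_{t+1} | k) * beta_{t+1}(k)
        b = trans_probs @ (lik[t + 1] * beta[t + 1])
        beta[t] = b / b.sum()
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)
```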
- state_moseq.hhmm_efficient.filtered_states(data, params, parallel=False)
Estimate filtered marginals of hidden states using the forward algorithm.
- Parameters:
data (dict)
params (dict)
parallel (bool)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
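Filtered marginals p(z_t | x_{0:t}) are the normalized forward messages alone, so each row conditions only on observations up to time t. A single-sequence NumPy sketch with a hypothetical name and an explicit initial distribution:

```python
import numpy as np

def filtered_marginals(log_likes, trans_probs, init_probs):
    """p(z_t = k | x_0..x_t) for one sequence: the normalized forward messages.

    No backward pass is run, so row t ignores observations after time t.
    """
    T, K = log_likes.shape
    lik = np.exp(log_likes - log_likes.max(axis=1, keepdims=True))
    filtered = np.zeros((T, K))
    pred = init_probs                       # p(z_0) before seeing any data
    for t in range(T):
        a = pred * lik[t]                   # combine the prediction with evidence at t
        filtered[t] = a / a.sum()
        pred = filtered[t] @ trans_probs    # p(z_{t+1} | x_0..x_t)
    return filtered
```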
- state_moseq.hhmm_efficient.predicted_states(data, params)
Predict hidden states using the Viterbi algorithm.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps']
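The Viterbi algorithm replaces the sums of the forward pass with maximizations and keeps back-pointers to recover the single most probable state path. A single-sequence sketch in log space, with a hypothetical name and an explicit initial distribution assumed:

```python
import numpy as np

def viterbi(log_likes, trans_probs, init_probs):
    """Most probable hidden-state path for one sequence, computed in log space."""
    T, K = log_likes.shape
    log_A = np.log(trans_probs)
    score = np.log(init_probs) + log_likes[0]
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        # cand[j, k]: best score ending in state j at t-1, then moving j -> k
        cand = score[:, None] + log_A
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_likes[t]
    # Trace the best path backwards from the highest-scoring final state
    path = np.zeros(T, dtype=int)
    path[-1] = score.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path
```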
- state_moseq.hhmm_efficient.random_params(seed, hypparams)
Generate random model parameters.
emission_base ~ Normal(0, emission_base_sigma)
emission_biases ~ Normal(0, emission_biases_sigma)
trans_probs ~ Dirichlet(trans_beta + trans_kappa * I), where I is the identity matrix
- Parameters:
seed (Float[Array, '2']) – random seed
hypparams (dict) – hyperparameters dictionary
- Returns:
params – parameters dictionary
- Return type:
dict
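The three priors above can be sampled directly with NumPy. In this sketch n_states and n_syllables are passed explicitly and the dictionary keys reuse the names in the formulas; both are assumptions about how hypparams is organized. The array shapes follow the annotations given for get_syllable_trans_probs, and the function name is hypothetical.

```python
import numpy as np

def random_params_sketch(rng, n_states, n_syllables, hypparams):
    """Hypothetical sketch: draw parameters from the priors listed for random_params."""
    K, S = n_states, n_syllables
    emission_base = rng.normal(0.0, hypparams["emission_base_sigma"], size=(S, S - 1))
    emission_biases = rng.normal(0.0, hypparams["emission_biases_sigma"], size=(K - 1, S - 1))
    # Row k ~ Dirichlet(beta * 1 + kappa * e_k): extra mass on the self-transition
    alpha = hypparams["trans_beta"] + hypparams["trans_kappa"] * np.eye(K)
    trans_probs = np.stack([rng.dirichlet(alpha[k]) for k in range(K)])
    return {"emission_base": emission_base, "emission_biases": emission_biases, "trans_probs": trans_probs}
```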
- state_moseq.hhmm_efficient.simulate(seed, params, n_timesteps, n_sequences)
Simulate data from the model.
- Parameters:
seed (Float[Array, '2']) – random seed
params (dict) – parameters dictionary
n_timesteps (Int) – number of timesteps to simulate
n_sequences (Int) – number of sessions to simulate
- Returns:
states – simulated hidden states
syllables – simulated syllables
- Return type:
tuple (states, syllables)
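Data from a hierarchical model like this can be generated by ancestral sampling: at each step draw the next hidden state from trans_probs, then draw the next syllable from that state's syllable transition matrix (for example the output of get_syllable_trans_probs). The sketch below assumes uniform distributions for the first state and first syllable; the function name is hypothetical.

```python
import numpy as np

def simulate_sketch(rng, trans_probs, syllable_trans, n_timesteps, n_sequences):
    """Hypothetical ancestral sampler.

    trans_probs: (n_states, n_states); syllable_trans: (n_states, n_syllables, n_syllables)
    with rows p(x_t | x_{t-1}, z_t). First state and syllable are drawn uniformly (assumption).
    """
    K, S = syllable_trans.shape[0], syllable_trans.shape[1]
    states = np.zeros((n_sequences, n_timesteps), dtype=int)
    syllables = np.zeros((n_sequences, n_timesteps), dtype=int)
    for n in range(n_sequences):
        states[n, 0] = rng.integers(K)
        syllables[n, 0] = rng.integers(S)
        for t in range(1, n_timesteps):
            # The hidden state moves first, then picks the syllable transition matrix
            states[n, t] = rng.choice(K, p=trans_probs[states[n, t - 1]])
            syllables[n, t] = rng.choice(S, p=syllable_trans[states[n, t], syllables[n, t - 1]])
    return states, syllables
```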
- state_moseq.hhmm_efficient.resample_params(seed, data, states, hypparams, params=None)
Resample parameters from their posterior distribution. Emission parameters are resampled using a Laplace approximation; the mode is found using gradient descent.
- Parameters:
seed (Float[Array, '2']) – random seed
data (dict) – data dictionary
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
hypparams (dict) – hyperparameters dictionary
params (dict | None) – parameters dictionary (optional; used for initializing gradient descent)
- Returns:
params – parameters dictionary
losses – losses recorded during gradient descent
- Return type:
tuple (params, losses)
- state_moseq.hhmm_efficient.resample_emission_params(seed, syllables, mask, states, n_states, n_syllables, emission_base_sigma, emission_biases_sigma, gradient_descent_iters=100, gradient_descent_lr=0.001, init_emission_base=None, init_emission_biases=None)
Resample emission parameters from their posterior distribution.
- Parameters:
seed (Float[Array, '2']) – random seed
syllables (Int[Array, 'n_sequences n_timesteps']) – syllable observations
mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
n_states (int) – number of hidden states
n_syllables (int) – number of syllables
emission_base_sigma (Float) – emission base standard deviation
emission_biases_sigma (Float) – emission biases standard deviation
gradient_descent_iters (Int) – number of gradient descent iterations
gradient_descent_lr (Float) – gradient descent learning rate
init_emission_base (Float[Array, 'n_syllables n_syllables-1']) – initial emission base parameters (optional)
init_emission_biases (Float[Array, 'n_states-1 n_syllables-1']) – initial emission biases parameters (optional)
- Returns:
emission_base – posterior emission base parameters
emission_biases – posterior emission biases parameters
losses – losses recorded during gradient descent
- Return type:
tuple (emission_base, emission_biases, losses)
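The Laplace approximation mentioned under resample_params finds the posterior mode by gradient descent and then samples from a Gaussian whose covariance is the inverse Hessian at that mode. The sketch below applies the idea to a single vector of logits with a reference category and a Normal(0, sigma) prior, given counts of each target syllable; the documented function works on emission_base and emission_biases jointly, and the helper names, step-size rule, and iteration count here are assumptions.

```python
import numpy as np

def softmax_ref(w):
    """Probabilities over [w, 0]: the last category is a reference with logit 0."""
    logits = np.append(w, 0.0)
    e = np.exp(logits - logits.max())
    return e / e.sum()

def laplace_sample(rng, counts, sigma, n_iters=1000, lr=1.0):
    """Hypothetical sketch: draw logits w from a Laplace approximation to p(w | counts).

    counts: (n_categories,) observed counts; prior w ~ Normal(0, sigma^2 I).
    Returns (sample, posterior mode, approximate posterior covariance).
    """
    n = counts.sum()
    w = np.zeros(counts.shape[0] - 1)
    # The Hessian's top eigenvalue is at most n + 1/sigma^2, so lr <= 1 keeps steps stable
    step = lr / (n + 1.0 / sigma**2)
    for _ in range(n_iters):
        p = softmax_ref(w)[:-1]
        # Gradient of the negative log posterior (multinomial likelihood + Gaussian prior)
        grad = -(counts[:-1] - n * p) + w / sigma**2
        w = w - step * grad
    # Hessian of the negative log posterior at the mode
    p = softmax_ref(w)[:-1]
    hess = n * (np.diag(p) - np.outer(p, p)) + np.eye(w.shape[0]) / sigma**2
    cov = np.linalg.inv(hess)
    return rng.multivariate_normal(w, cov), w, cov
```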
- state_moseq.hhmm_efficient.resample_trans_probs(seed, mask, states, n_states, beta, kappa)
Resample transition probabilities from their posterior distribution.
- Parameters:
seed (Float[Array, '2']) – random seed
mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
n_states (int) – number of hidden states
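beta (Float) – Dirichlet concentration added to every entry of each transition row
kappa (Float) – extra Dirichlet concentration on self-transitions (the diagonal), favoring persistence in the same state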
- Returns:
trans_probs – posterior transition probabilities
- Return type:
Float[Array, 'n_states n_states']
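With a Dirichlet prior, the transition-matrix update is conjugate: add the observed state-to-state transition counts to the prior concentrations and draw each row from the resulting Dirichlet. In this sketch prior row k is read from Dirichlet(trans_beta + trans_kappa * I) as beta everywhere plus kappa on the diagonal, and transitions touching a masked step are skipped; the function name and the masking rule are assumptions.

```python
import numpy as np

def resample_trans_probs_sketch(rng, mask, states, n_states, beta, kappa):
    """Hypothetical conjugate update: row k ~ Dirichlet(beta + kappa * e_k + counts[k])."""
    counts = np.zeros((n_states, n_states))
    valid = (mask[:, :-1] > 0) & (mask[:, 1:] > 0)  # skip transitions touching masked steps
    # counts[j, k] = number of observed j -> k hidden-state transitions
    np.add.at(counts, (states[:, :-1][valid], states[:, 1:][valid]), 1)
    alpha = beta + kappa * np.eye(n_states) + counts
    return np.stack([rng.dirichlet(alpha[k]) for k in range(n_states)])
```

Long runs in one state push that row's self-transition probability toward 1, as expected from the count term dominating the prior.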