Standard HHMM

Data:
na – module-level attribute (None) whose docstring documents the data, states, params, and hypparams formats listed below.

Functions:
obs_log_likelihoods – Compute log likelihoods of observations for each hidden state.
log_params_prob – Compute the log probability of the parameters based on their priors.
log_joint_prob – Compute the log joint probability of the data and parameters.
resample_states – Resample hidden states from their posterior distribution.
fit_gibbs – Fit a model using Gibbs sampling.
initialize_params – Initialize parameters by sampling from their prior distribution or using provided states.
marginal_loglik – Estimate the marginal log likelihood of the data.
smoothed_states – Estimate smoothed marginals of hidden states using the forward-backward algorithm.
filtered_states – Estimate filtered marginals of hidden states using the forward pass.
predicted_states – Predict the most likely sequence of hidden states using the Viterbi algorithm.
random_params – Generate random model parameters.
simulate – Simulate data from the model.
resample_params – Resample parameters from their posterior distribution.
resample_emission_params – Resample emission parameters from their posterior distribution.
resample_trans_probs – Resample transition probabilities from their posterior distribution.
- state_moseq.hhmm_standard.na = None
- data = {
"syllables": (n_sequences, n_timesteps, n_syllables), "mask": (n_sequences, n_timesteps),
}
- states: (n_sequences, n_timesteps)
- params = {
"emissions": (n_states, n_syllables, n_syllables), "trans_probs": (n_states, n_states),
}
- hypparams = {
"n_states": (), "emission_beta": (), "trans_beta": (), "trans_kappa": (), "n_syllables": (),
}
- state_moseq.hhmm_standard.obs_log_likelihoods(data, params)[source]
Compute log likelihoods of observations for each hidden state.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
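The entry above does not say how an observation log likelihood is formed. Because `emissions` has shape (n_states, n_syllables, n_syllables) and is described elsewhere as syllable transition probabilities for each state, one plausible reading is that the log likelihood at timestep t under state k is log emissions[k, s_{t-1}, s_t]. The minimal NumPy sketch below assumes that reading, handles one integer-coded sequence, and treats the first timestep and masked timesteps as contributing zero. `obs_log_likelihoods_sketch` is a hypothetical name, not part of state_moseq.

```python
import numpy as np

def obs_log_likelihoods_sketch(syllables, mask, emissions):
    """Hypothetical single-sequence analogue of obs_log_likelihoods.

    syllables: int array (n_timesteps,) of syllable indices
    mask: array (n_timesteps,), 1 for valid timesteps and 0 for padding
    emissions: array (n_states, n_syllables, n_syllables); row emissions[k, i]
        is ASSUMED to be p(next syllable | previous syllable i, state k).
    Returns an (n_timesteps, n_states) array of log likelihoods.
    """
    n_states = emissions.shape[0]
    ll = np.zeros((syllables.shape[0], n_states))
    log_em = np.log(emissions)
    # log p(s_t | s_{t-1}, z_t = k) for t >= 1, gathered for every state k at once.
    # Assumption: t = 0 has no previous syllable and contributes 0 for all states.
    ll[1:] = log_em[:, syllables[:-1], syllables[1:]].T
    # Assumption: masked (padded) timesteps contribute nothing.
    return np.where(mask[:, None] > 0, ll, 0.0)
```

Using `np.where` rather than multiplying by the mask keeps padded entries at 0 even when an emission probability is exactly zero (0 * -inf would give NaN).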
- state_moseq.hhmm_standard.log_params_prob(params, hypparams)[source]
Compute the log probability of the parameters based on their priors.
- Parameters:
params (dict)
hypparams (dict)
- Return type:
Float
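As an illustration of a prior log probability for this model, the sketch below sums Dirichlet log densities over every row of `emissions` and `trans_probs`, using the priors stated under random_params: a symmetric Dirichlet with concentration emission_beta for each emission row, and trans_beta plus trans_kappa on the diagonal for each transition row. The helper names are hypothetical, the key `emission_beta` follows the hypparams listing above, and the library's own implementation (in JAX) may be organized differently.

```python
import math
import numpy as np

def dirichlet_logpdf(x, alpha):
    """Log density of Dirichlet(alpha) evaluated at the probability vector x."""
    return (math.lgamma(float(np.sum(alpha)))
            - sum(math.lgamma(float(a)) for a in alpha)
            + float(np.sum((alpha - 1.0) * np.log(x))))

def log_params_prob_sketch(params, hypparams):
    """Hypothetical log p(params) under independent Dirichlet priors on every row."""
    emissions = params["emissions"]      # (n_states, n_syllables, n_syllables)
    trans_probs = params["trans_probs"]  # (n_states, n_states)
    n_states, n_syllables, _ = emissions.shape
    lp = 0.0
    # Each emission row emissions[k, i] ~ Dirichlet(emission_beta * ones)
    alpha_e = np.full(n_syllables, float(hypparams["emission_beta"]))
    for k in range(n_states):
        for i in range(n_syllables):
            lp += dirichlet_logpdf(emissions[k, i], alpha_e)
    # Transition row k ~ Dirichlet(trans_beta + trans_kappa * e_k) (sticky prior)
    for k in range(n_states):
        alpha_t = np.full(n_states, float(hypparams["trans_beta"]))
        alpha_t[k] += hypparams["trans_kappa"]
        lp += dirichlet_logpdf(trans_probs[k], alpha_t)
    return lp
```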
- state_moseq.hhmm_standard.log_joint_prob(data, params, hypparams)[source]
Compute the log joint probability of the data and parameters.
- Parameters:
data (dict)
params (dict)
hypparams (dict)
- Return type:
Float
- state_moseq.hhmm_standard.resample_states(seed, data, params, parallel=False)[source]
Resample hidden states from their posterior distribution.
- Parameters:
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
data (dict) – data dictionary
params (dict) – parameters dictionary
parallel (bool) – whether to use parallel message passing
- Returns:
states – resampled hidden states
marginal_loglik – marginal log likelihood of the data
- Return type:
tuple (states, marginal_loglik)
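Sampling HMM states from their posterior is commonly done with forward filtering, backward sampling (FFBS), which also yields the marginal log likelihood as a by-product. The NumPy sketch below shows that technique for one sequence, given an (n_timesteps, n_states) log-likelihood matrix such as obs_log_likelihoods produces. It assumes an explicit initial-state distribution `init_probs`, which does not appear in the documented params; `ffbs_sketch` is a hypothetical name, and this is not claimed to be how resample_states is implemented (its parallel option is not reproduced).

```python
import numpy as np

def _logsumexp(a, axis=None):
    # Numerically stable log(sum(exp(a))) along `axis` (or over all entries)
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else out.item()

def ffbs_sketch(rng, log_likes, trans_probs, init_probs):
    """Hypothetical FFBS for one sequence.

    log_likes: (n_timesteps, n_states) observation log likelihoods
    trans_probs: (n_states, n_states), rows sum to 1
    init_probs: (n_states,) ASSUMED initial-state distribution
    Returns (states, marginal_loglik).
    """
    T, K = log_likes.shape
    log_A = np.log(trans_probs)
    # Forward filtering: log_alpha[t, k] = log p(x_{1:t}, z_t = k)
    log_alpha = np.zeros((T, K))
    log_alpha[0] = np.log(init_probs) + log_likes[0]
    for t in range(1, T):
        log_alpha[t] = log_likes[t] + _logsumexp(log_alpha[t - 1][:, None] + log_A, axis=0)
    marginal_loglik = _logsumexp(log_alpha[-1])
    # Backward sampling: draw z_T, then z_t given z_{t+1} with p ∝ alpha_t(j) * A[j, z_{t+1}]
    states = np.zeros(T, dtype=int)
    p = np.exp(log_alpha[-1] - marginal_loglik)
    states[-1] = rng.choice(K, p=p / p.sum())
    for t in range(T - 2, -1, -1):
        logp = log_alpha[t] + log_A[:, states[t + 1]]
        p = np.exp(logp - _logsumexp(logp))
        states[t] = rng.choice(K, p=p / p.sum())
    return states, marginal_loglik
```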
- state_moseq.hhmm_standard.fit_gibbs(data, hypparams, init_params, init_states=None, seed=Array([0, 0], dtype=uint32), num_iters=100, parallel=False)[source]
Fit a model using Gibbs sampling.
- Parameters:
data (dict) – data dictionary
hypparams (dict) – hyperparameters dictionary
init_params (dict) – initial parameters dictionary
init_states (Int[Array, 'n_sequences n_timesteps'] | None) – initial hidden states (optional)
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
num_iters (Int) – number of iterations
parallel (bool) – whether to use parallel message passing
- Returns:
params – fitted parameters dictionary
log_joints – log joint probability of the data and parameters, recorded at each iteration
- Return type:
tuple (params, log_joints)
- state_moseq.hhmm_standard.initialize_params(data, hypparams, states=None, seed=Array([0, 0], dtype=uint32))[source]
Initialize parameters by sampling from their prior distribution or using provided states.
- Parameters:
data (dict) – data dictionary
hypparams (dict) – hyperparameters dictionary
states (Int[Array, 'n_sequences n_timesteps'] | None) – states used for initializing the parameters (optional)
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
- Return type:
dict
- state_moseq.hhmm_standard.marginal_loglik(data, params)[source]
Estimate the marginal log likelihood of the data.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
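The marginal log likelihood of an HMM, log p(x_{1:T}) with the hidden states summed out, can be computed with the forward algorithm. Below is a minimal single-sequence NumPy sketch that normalizes the forward messages at every timestep to avoid underflow and accumulates the log normalizers. The initial distribution `init_probs` and the name `forward_loglik_sketch` are assumptions for illustration; the sketch returns one scalar per sequence.

```python
import numpy as np

def forward_loglik_sketch(log_likes, trans_probs, init_probs):
    """Hypothetical log p(x_{1:T}) for one sequence via the scaled forward algorithm.

    log_likes: (n_timesteps, n_states); trans_probs: (n_states, n_states);
    init_probs: (n_states,) ASSUMED initial-state distribution.
    """
    T, _ = log_likes.shape
    loglik = 0.0
    alpha = np.asarray(init_probs, dtype=float)  # predicted p(z_1)
    for t in range(T):
        # Weight the prediction by p(x_t | z_t), shifting logs by their max for stability
        m = log_likes[t].max()
        w = alpha * np.exp(log_likes[t] - m)
        norm = w.sum()
        loglik += m + np.log(norm)   # log p(x_t | x_{1:t-1})
        alpha = w / norm             # filtered p(z_t | x_{1:t})
        alpha = alpha @ trans_probs  # predicted p(z_{t+1} | x_{1:t})
    return loglik
```

The normalized `alpha` inside the loop, before the prediction step, is also the filtered marginal at timestep t.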
- state_moseq.hhmm_standard.smoothed_states(data, params)[source]
Estimate smoothed marginals of hidden states (conditioned on all observations) using the forward-backward algorithm.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
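To make the forward-backward computation concrete, here is a minimal single-sequence NumPy sketch of smoothed marginals p(z_t = k | x_{1:T}). It assumes an explicit initial distribution `init_probs`, and `smoothed_marginals_sketch` is a hypothetical name; both the forward and backward messages are renormalized at each step, which is valid because any per-timestep constant cancels in the final normalization.

```python
import numpy as np

def smoothed_marginals_sketch(log_likes, trans_probs, init_probs):
    """Hypothetical smoothed marginals for one sequence; returns (n_timesteps, n_states)."""
    T, K = log_likes.shape
    # Rescale likelihoods per timestep; constant factors cancel after normalization
    L = np.exp(log_likes - log_likes.max(axis=1, keepdims=True))
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    # Forward pass: alpha[t] ∝ p(z_t | x_{1:t})
    a = init_probs * L[0]
    alpha[0] = a / a.sum()
    for t in range(1, T):
        a = (alpha[t - 1] @ trans_probs) * L[t]
        alpha[t] = a / a.sum()
    # Backward pass: beta[t, j] ∝ sum_k A[j, k] * L[t+1, k] * beta[t+1, k]
    for t in range(T - 2, -1, -1):
        b = trans_probs @ (L[t + 1] * beta[t + 1])
        beta[t] = b / b.sum()
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)
```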
- state_moseq.hhmm_standard.filtered_states(data, params)[source]
Estimate filtered marginals of hidden states (conditioned on observations up to each timestep) using the forward pass of the forward-backward algorithm.
- Parameters:
data (dict)
params (dict)
- Return type:
Float[Array, 'n_sequences n_timesteps n_states']
- state_moseq.hhmm_standard.predicted_states(data, params)[source]
Predict the most likely sequence of hidden states using the Viterbi algorithm.
- Parameters:
data (dict)
params (dict)
- Return type:
Int[Array, 'n_sequences n_timesteps']
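The Viterbi algorithm replaces the sum in the forward recursion with a max and keeps back-pointers to recover the single most likely path. The NumPy sketch below works in log space for one sequence; `init_probs` is an assumed initial distribution and `viterbi_sketch` is a hypothetical name, not the library function.

```python
import numpy as np

def viterbi_sketch(log_likes, trans_probs, init_probs):
    """Hypothetical Viterbi decoding for one sequence; returns int array (n_timesteps,)."""
    T, K = log_likes.shape
    log_A = np.log(trans_probs)
    # score[k]: best log probability of any path ending in state k at the current step
    score = np.log(init_probs) + log_likes[0]
    backptr = np.zeros((T, K), dtype=int)
    for t in range(1, T):
        cand = score[:, None] + log_A  # cand[j, k]: come from j, move to k
        backptr[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_likes[t]
    # Trace the best path backwards from the best final state
    states = np.zeros(T, dtype=int)
    states[-1] = score.argmax()
    for t in range(T - 1, 0, -1):
        states[t - 1] = backptr[t, states[t]]
    return states
```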
- state_moseq.hhmm_standard.random_params(seed, hypparams)[source]
Generate random model parameters.
emissions ~ Dirichlet(emissions_beta), drawn independently for each state and syllable
trans_probs ~ Dirichlet(trans_beta + trans_kappa * I), drawn row by row, where I is the identity matrix
- Parameters:
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
hypparams (dict) – hyperparameters dictionary
- Returns:
params – parameters dictionary
- Return type:
dict
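The two priors above can be sampled directly. This NumPy sketch draws one Dirichlet vector per (state, syllable) emission row and one per transition row, adding trans_kappa to the self-transition entry. It reads the hypparams keys `n_states`, `n_syllables`, `emission_beta`, `trans_beta`, and `trans_kappa` from the listing at the top of this page (the formula above writes `emissions_beta`); `random_params_sketch` is a hypothetical name.

```python
import numpy as np

def random_params_sketch(rng, hypparams):
    """Hypothetical prior draw of params; rng is a numpy.random.Generator."""
    K = hypparams["n_states"]
    S = hypparams["n_syllables"]
    # One Dirichlet draw per (state, syllable) row -> shape (K, S, S)
    emissions = rng.dirichlet(np.full(S, hypparams["emission_beta"]), size=(K, S))
    # Sticky prior: row k has concentration trans_beta everywhere plus trans_kappa at k
    alpha = hypparams["trans_beta"] + hypparams["trans_kappa"] * np.eye(K)
    trans_probs = np.stack([rng.dirichlet(alpha[k]) for k in range(K)])
    return {"emissions": emissions, "trans_probs": trans_probs}
```

A larger trans_kappa pushes each transition row toward its diagonal entry, so sampled state sequences tend to persist longer.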
- state_moseq.hhmm_standard.simulate(seed, params, n_timesteps, n_sequences)[source]
Simulate data from the model.
- Parameters:
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
params (dict) – parameters dictionary
n_timesteps (Int) – number of timesteps to simulate
n_sequences (Int) – number of sessions to simulate
- Returns:
states – simulated hidden states
syllables – simulated syllables
- Return type:
tuple (states, syllables)
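Simulation from a model like this is ancestral sampling: draw the next hidden state from the current state's transition row, then draw the next syllable from that state's syllable transition matrix. The NumPy sketch below assumes emissions[k, i] is the next-syllable distribution given state k and previous syllable i, and that the first state and first syllable are uniform; both are assumptions, and `simulate_sketch` is a hypothetical name.

```python
import numpy as np

def simulate_sketch(rng, params, n_timesteps, n_sequences):
    """Hypothetical ancestral sampler; returns int arrays (n_sequences, n_timesteps)."""
    emissions = params["emissions"]      # (K, S, S), rows ASSUMED to be p(next | prev, state)
    trans_probs = params["trans_probs"]  # (K, K)
    K, S, _ = emissions.shape
    states = np.zeros((n_sequences, n_timesteps), dtype=int)
    syllables = np.zeros((n_sequences, n_timesteps), dtype=int)
    for n in range(n_sequences):
        # Assumption: uniform initial state and initial syllable
        states[n, 0] = rng.integers(K)
        syllables[n, 0] = rng.integers(S)
        for t in range(1, n_timesteps):
            states[n, t] = rng.choice(K, p=trans_probs[states[n, t - 1]])
            # Next syllable from the current state's syllable transition matrix
            syllables[n, t] = rng.choice(S, p=emissions[states[n, t], syllables[n, t - 1]])
    return states, syllables
```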
- state_moseq.hhmm_standard.resample_params(seed, data, params, states, hypparams)[source]
Resample parameters from their posterior distribution.
- Parameters:
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
data (dict) – data dictionary
params (dict) – parameters dictionary
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
hypparams (dict) – hyperparameters dictionary
- Returns:
params – parameters dictionary
- Return type:
dict
- state_moseq.hhmm_standard.resample_emission_params(seed, syllables, mask, states, n_states, n_syllables, emissions_beta)[source]
Resample emission parameters from their posterior distribution.
- Parameters:
seed (UInt32[Array, '2']) – random seed (JAX PRNG key)
syllables (Int[Array, 'n_sequences n_timesteps']) – syllable observations
mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
n_states (int) – number of hidden states
n_syllables (int) – number of syllables
emissions_beta (Float) – Dirichlet concentration parameter for the emission probabilities
- Returns:
emissions – syllable transition probabilities for each state
- Return type:
Float[Array, 'n_states n_syllables n_syllables']
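Because the Dirichlet prior is conjugate to categorical observations, the posterior for each emission row is again Dirichlet, with the observed syllable-to-syllable counts added to the prior concentration. The NumPy sketch below counts, for each state active at time t, the transition from s_{t-1} to s_t, then samples every row at once by normalizing independent Gamma draws (a standard way to sample a Dirichlet). It assumes integer-coded syllables, that a transition counts only when both timesteps are unmasked, and the emissions[k, prev, next] layout; `resample_emissions_sketch` is a hypothetical name.

```python
import numpy as np

def resample_emissions_sketch(rng, syllables, mask, states, n_states, n_syllables, emissions_beta):
    """Hypothetical conjugate update; syllables, mask, states are int arrays (n_sequences, n_timesteps)."""
    counts = np.zeros((n_states, n_syllables, n_syllables))
    # Assumption: the transition s_{t-1} -> s_t is attributed to the state at time t,
    # and is counted only if both timesteps are unmasked.
    valid = (mask[:, 1:] > 0) & (mask[:, :-1] > 0)
    np.add.at(counts, (states[:, 1:][valid], syllables[:, :-1][valid], syllables[:, 1:][valid]), 1)
    # Dirichlet(beta) prior + categorical counts -> Dirichlet(beta + counts) posterior per row
    g = rng.gamma(emissions_beta + counts)
    return g / g.sum(axis=-1, keepdims=True)
```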
- state_moseq.hhmm_standard.resample_trans_probs(seed, mask, states, n_states, beta, kappa)[source]
Resample transition probabilities from their posterior distribution.
- Parameters:
seed (Float[Array, '2']) – random seed
mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations
states (Int[Array, 'n_sequences n_timesteps']) – hidden states
n_states (int) – number of hidden states
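beta (Float) – Dirichlet concentration added to every entry of each transition row
kappa (Float) – additional Dirichlet concentration on the self-transition (diagonal) entry, which favors remaining in the same state
- Returns:
trans_probs – transition probabilities sampled from their posterior
- Return type:
Float[Array, 'n_states n_states']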
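The same conjugacy applies to the sticky transition prior: row k of the posterior is Dirichlet(beta + kappa * e_k + counts[k]), where counts[k, j] is the number of observed k-to-j state transitions. The NumPy sketch below assumes a transition counts only when both timesteps are unmasked; `resample_trans_probs_sketch` is a hypothetical name, not the library function.

```python
import numpy as np

def resample_trans_probs_sketch(rng, mask, states, n_states, beta, kappa):
    """Hypothetical posterior draw of trans_probs; mask and states are (n_sequences, n_timesteps)."""
    counts = np.zeros((n_states, n_states))
    # Assumption: count z_{t-1} -> z_t only when both timesteps are unmasked
    valid = (mask[:, 1:] > 0) & (mask[:, :-1] > 0)
    np.add.at(counts, (states[:, :-1][valid], states[:, 1:][valid]), 1)
    # Sticky Dirichlet prior plus transition counts gives the posterior concentration
    alpha = beta + kappa * np.eye(n_states) + counts
    # Normalized independent Gamma draws give one Dirichlet sample per row
    g = rng.gamma(alpha)
    return g / g.sum(axis=1, keepdims=True)
```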