Standard HHMM

Data:

na

Docstring describing the layouts of the data, states, params, and hypparams dictionaries.

Functions:

obs_log_likelihoods(data, params)

Compute log likelihoods of observations for each hidden state.

log_params_prob(params, hypparams)

Compute the log probability of the parameters based on their priors.

log_joint_prob(data, params, hypparams)

Compute the log joint probability of the data and parameters.

resample_states(seed, data, params[, parallel])

Resample hidden states from their posterior distribution.

fit_gibbs(data, hypparams, init_params[, ...])

Fit a model using Gibbs sampling.

initialize_params(data, hypparams[, states, ...])

Initialize parameters by sampling from their prior distribution or using provided states.

marginal_loglik(data, params)

Estimate the marginal log likelihood of the data.

smoothed_states(data, params)

Estimate smoothed marginals of hidden states using the forward-backward algorithm.

filtered_states(data, params)

Estimate filtered marginals of hidden states using the forward algorithm.

predicted_states(data, params)

Predict hidden states using the Viterbi algorithm.

random_params(seed, hypparams)

Generate random model parameters.

simulate(seed, params, n_timesteps, n_sequences)

Simulate data from the model.

resample_params(seed, data, params, states, ...)

Resample parameters from their posterior distribution.

resample_emission_params(seed, syllables, ...)

Resample emission parameters from their posterior distribution.

resample_trans_probs(seed, mask, states, ...)

Resample transition probabilities from their posterior distribution.

state_moseq.hhmm_standard.na = None
data = {
    "syllables": (n_sequences, n_timesteps, n_syllables),
    "mask": (n_sequences, n_timesteps),
}

states: (n_sequences, n_timesteps)

params = {
    "emissions": (n_states, n_syllables, n_syllables),
    "trans_probs": (n_states, n_states),
}

hypparams = {
    "n_states": (,),
    "emission_beta": (,),
    "trans_beta": (,),
    "trans_kappa": (,),
    "n_syllables": (,),
}

state_moseq.hhmm_standard.obs_log_likelihoods(data, params)[source]

Compute log likelihoods of observations for each hidden state.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
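The following is an illustrative NumPy sketch of how this array could be computed. It assumes two things. First, that syllables are integer labels of shape (n_sequences, n_timesteps), as in the resample_emission_params entry. Second, that emissions[k, i, j] is the probability of syllable j following syllable i while in hidden state k. The function name obs_log_likelihoods_sketch is hypothetical. Assigning a log-likelihood of 0 to the first timestep is also an assumption, not documented library behavior.

```python
import numpy as np

def obs_log_likelihoods_sketch(syllables, emissions):
    """Hypothetical sketch, not the library's implementation.

    syllables: int array (n_sequences, n_timesteps) of syllable labels (assumed)
    emissions: (n_states, n_syllables, n_syllables) per-state syllable transitions
    Returns an array of shape (n_sequences, n_timesteps, n_states).
    """
    log_em = np.log(emissions)
    prev, curr = syllables[:, :-1], syllables[:, 1:]
    # Advanced indexing gives shape (n_states, n_sequences, n_timesteps - 1)
    ll = np.moveaxis(log_em[:, prev, curr], 0, -1)
    # Assumption: the first timestep has no preceding syllable, so it contributes 0
    first = np.zeros((syllables.shape[0], 1, emissions.shape[0]))
    return np.concatenate([first, ll], axis=1)
```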

state_moseq.hhmm_standard.log_params_prob(params, hypparams)[source]

Compute the log probability of the parameters based on their priors.

Parameters:
  • params (dict)

  • hypparams (dict)

Return type:

Float
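Below is a minimal sketch of this computation. It assumes the priors listed under random_params: each emission row is drawn from Dirichlet(emission_beta), and transition row k is drawn from Dirichlet(trans_beta) with trans_kappa added at entry k. It also assumes there are no other prior terms. The helpers dirichlet_logpdf and log_params_prob_sketch are hypothetical.

```python
import math
import numpy as np

def dirichlet_logpdf(x, alpha):
    """Log density of Dirichlet(alpha) at the probability vector x."""
    return (math.lgamma(alpha.sum()) - sum(math.lgamma(a) for a in alpha)
            + float(np.sum((alpha - 1.0) * np.log(x))))

def log_params_prob_sketch(params, hypparams):
    """Hypothetical sketch: sum of Dirichlet log densities over every row."""
    em, tp = params["emissions"], params["trans_probs"]
    n_states, n_syl = em.shape[0], em.shape[1]
    lp = 0.0
    # Emission rows: one Dirichlet(emission_beta) per (state, previous syllable)
    alpha_em = np.full(n_syl, hypparams["emission_beta"], dtype=float)
    for k in range(n_states):
        for i in range(n_syl):
            lp += dirichlet_logpdf(em[k, i], alpha_em)
    # Transition rows: sticky prior with extra weight trans_kappa on the diagonal
    for k in range(n_states):
        alpha_tr = np.full(n_states, hypparams["trans_beta"], dtype=float)
        alpha_tr[k] += hypparams["trans_kappa"]
        lp += dirichlet_logpdf(tp[k], alpha_tr)
    return lp
```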

state_moseq.hhmm_standard.log_joint_prob(data, params, hypparams)[source]

Compute the log joint probability of the data and parameters.

Parameters:
  • data (dict)

  • params (dict)

  • hypparams (dict)

Return type:

Float
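One reading of "data and parameters" is that the hidden states z are summed out. Under that reading, this quantity presumably splits into the term estimated by marginal_loglik plus the term computed by log_params_prob. This decomposition is an interpretation and is not confirmed by the source:

```latex
\log p(x, \theta) = \log \sum_{z} p(x, z \mid \theta) + \log p(\theta)
```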

state_moseq.hhmm_standard.resample_states(seed, data, params, parallel=False)[source]

Resample hidden states from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • data (dict) – data dictionary

  • params (dict) – parameters dictionary

  • parallel (bool) – whether to use parallel message passing

Returns:

  • states – resampled hidden states

  • marginal_loglik – marginal log likelihood of the data

Return type:

tuple (states, marginal_loglik)
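Drawing hidden states from their posterior is commonly done by forward filtering, backward sampling (FFBS). The NumPy sketch below handles a single unmasked sequence and does not model the parallel option. The function name ffbs_sketch and the explicit init_probs argument are assumptions.

```python
import numpy as np

def ffbs_sketch(rng, log_liks, trans_probs, init_probs):
    """Forward filtering, backward sampling for one sequence (hypothetical sketch).

    log_liks: (n_timesteps, n_states) observation log-likelihoods
    trans_probs: (n_states, n_states); init_probs: (n_states,) (assumed input)
    Returns (states, marginal_loglik).
    """
    T, K = log_liks.shape
    alphas = np.zeros((T, K))  # normalized filtered distributions p(z_t | x_{1:t})
    loglik = 0.0
    pred = init_probs
    for t in range(T):
        m = log_liks[t].max()  # rescale to avoid underflow
        unnorm = pred * np.exp(log_liks[t] - m)
        norm = unnorm.sum()
        loglik += np.log(norm) + m
        alphas[t] = unnorm / norm
        pred = alphas[t] @ trans_probs
    states = np.zeros(T, dtype=int)
    states[-1] = rng.choice(K, p=alphas[-1])
    for t in range(T - 2, -1, -1):
        # p(z_t | z_{t+1}, x_{1:t}) is proportional to alpha_t(z_t) * A[z_t, z_{t+1}]
        w = alphas[t] * trans_probs[:, states[t + 1]]
        states[t] = rng.choice(K, p=w / w.sum())
    return states, loglik
```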

state_moseq.hhmm_standard.fit_gibbs(data, hypparams, init_params, init_states=None, seed=Array([0, 0], dtype=uint32), num_iters=100, parallel=False)[source]

Fit a model using Gibbs sampling.

Parameters:
  • data (dict) – data dictionary

  • hypparams (dict) – hyperparameters dictionary

  • init_params (dict) – initial parameters dictionary

  • init_states (Int[Array, 'n_sequences n_timesteps'] | None) – initial hidden states (optional)

  • seed (Float[Array, '2']) – random seed

  • num_iters (Int) – number of iterations

  • parallel (bool) – whether to use parallel message passing

Returns:

  • params – fitted parameters dictionary

  • log_joints – log joint probability of the data and parameters, recorded at each iteration

Return type:

tuple (params, log_joints)
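In Gibbs sampling, the sampler alternates between two steps: resampling the hidden states given the parameters, and resampling the parameters given the states. The loop below is a hypothetical skeleton of that alternation. Its three callables stand in for functions such as resample_states, resample_params, and log_joint_prob. Their signatures are simplified assumptions; for example, the state step here returns only the states.

```python
import numpy as np

def gibbs_sketch(rng, data, params, hypparams, resample_states_fn,
                 resample_params_fn, log_joint_fn, num_iters=100):
    """Generic blocked Gibbs loop with hypothetical callable signatures."""
    log_joints = []
    for _ in range(num_iters):
        # z ~ p(z | x, theta)
        states = resample_states_fn(rng, data, params)
        # theta ~ p(theta | x, z)
        params = resample_params_fn(rng, data, params, states, hypparams)
        log_joints.append(log_joint_fn(data, params, hypparams))
    return params, np.array(log_joints)
```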

state_moseq.hhmm_standard.initialize_params(data, hypparams, states=None, seed=Array([0, 0], dtype=uint32))[source]

Initialize parameters by sampling from their prior distribution or using provided states.

Parameters:
  • data (dict) – data dictionary

  • hypparams (dict) – hyperparameters dictionary

  • states (Int[Array, 'n_sequences n_timesteps'] | None) – states used for initializing the parameters (optional)

  • seed (Float[Array, '2']) – random seed

Return type:

dict

state_moseq.hhmm_standard.marginal_loglik(data, params)[source]

Estimate the marginal log likelihood of the data.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float
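The marginal log likelihood sums over all hidden-state paths, which a forward pass in log space can compute. The sketch below covers one unmasked sequence. The names marginal_loglik_sketch, logsumexp (a local helper), and init_probs are illustrative assumptions.

```python
import numpy as np

def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along axis (local helper)."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)

def marginal_loglik_sketch(log_liks, trans_probs, init_probs):
    """Forward algorithm in log space for one sequence (hypothetical sketch)."""
    log_A = np.log(trans_probs)
    log_alpha = np.log(init_probs) + log_liks[0]  # log p(x_1, z_1)
    for t in range(1, len(log_liks)):
        # log alpha_t(j) = logsumexp_i[log alpha_{t-1}(i) + log A(i, j)] + log p(x_t | z_t = j)
        log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_liks[t]
    return logsumexp(log_alpha)  # log p(x_{1:T})
```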

state_moseq.hhmm_standard.smoothed_states(data, params)[source]

Estimate smoothed marginals of hidden states using the forward-backward algorithm.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
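Smoothed marginals p(z_t | x_{1:T}) combine a forward (filtering) pass with a backward pass. The NumPy sketch below handles a single unmasked sequence; smoothed_states_sketch and init_probs are assumptions. At each step the likelihoods are rescaled to avoid underflow. This rescaling leaves the normalized marginals unchanged.

```python
import numpy as np

def smoothed_states_sketch(log_liks, trans_probs, init_probs):
    """Forward-backward smoothing for one sequence (hypothetical sketch)."""
    T, K = log_liks.shape
    lik = np.exp(log_liks - log_liks.max(axis=1, keepdims=True))
    # Forward pass: normalized filtered distributions
    alphas = np.zeros((T, K))
    pred = init_probs
    for t in range(T):
        a = pred * lik[t]
        alphas[t] = a / a.sum()
        pred = alphas[t] @ trans_probs
    # Backward pass: beta_t(i) proportional to sum_j A[i, j] lik_{t+1}(j) beta_{t+1}(j)
    betas = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        b = trans_probs @ (lik[t + 1] * betas[t + 1])
        betas[t] = b / b.sum()
    post = alphas * betas
    return post / post.sum(axis=1, keepdims=True)
```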

state_moseq.hhmm_standard.filtered_states(data, params)[source]

Estimate filtered marginals of hidden states using the forward algorithm.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Float[Array, 'n_sequences n_timesteps n_states']
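Filtered marginals p(z_t | x_{1:t}) come from the forward pass alone, so they do not condition on future observations. The sketch below is for one unmasked sequence, with filtered_states_sketch and init_probs as assumptions.

```python
import numpy as np

def filtered_states_sketch(log_liks, trans_probs, init_probs):
    """Forward filtering for one sequence (hypothetical sketch)."""
    T, K = log_liks.shape
    # Rescale each timestep's likelihoods by their max to avoid underflow
    lik = np.exp(log_liks - log_liks.max(axis=1, keepdims=True))
    filtered = np.zeros((T, K))
    pred = init_probs  # predictive distribution p(z_t | x_{1:t-1})
    for t in range(T):
        a = pred * lik[t]
        filtered[t] = a / a.sum()
        pred = filtered[t] @ trans_probs
    return filtered
```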

state_moseq.hhmm_standard.predicted_states(data, params)[source]

Predict hidden states using the Viterbi algorithm.

Parameters:
  • data (dict)

  • params (dict)

Return type:

Int[Array, 'n_sequences n_timesteps']
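Unlike per-timestep marginals, the Viterbi algorithm returns the single most probable state path. The log-space NumPy sketch below handles one sequence; viterbi_sketch and init_probs are assumptions.

```python
import numpy as np

def viterbi_sketch(log_liks, trans_probs, init_probs):
    """Most probable state path for one sequence (hypothetical sketch)."""
    T, K = log_liks.shape
    log_A = np.log(trans_probs)
    score = np.log(init_probs) + log_liks[0]
    back = np.zeros((T, K), dtype=int)  # best predecessor of each state
    for t in range(1, T):
        cand = score[:, None] + log_A  # cand[i, j]: best path ending with i -> j
        back[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_liks[t]
    path = np.zeros(T, dtype=int)
    path[-1] = score.argmax()
    for t in range(T - 1, 0, -1):
        path[t - 1] = back[t, path[t]]
    return path
```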

state_moseq.hhmm_standard.random_params(seed, hypparams)[source]

Generate random model parameters.

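emissions ~ Dirichlet(emissions_beta), drawn independently for each state and syllable

trans_probs[k] ~ Dirichlet(trans_beta + trans_kappa * I[k]) for each state k, where I is the identity matrix, so trans_kappa adds extra weight to self-transitions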

Parameters:
  • seed (Float[Array, '2']) – random seed

  • hypparams (dict) – hyperparameters dictionary

Returns:

params – parameters dictionary

Return type:

dict
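Both priors above can be sampled directly with NumPy's Generator.dirichlet. This sketch uses the hyperparameter key names from the hypparams layout at the top of this page (for example, emission_beta). The name random_params_sketch is hypothetical, and NumPy's Generator stands in for the library's JAX random keys.

```python
import numpy as np

def random_params_sketch(rng, hypparams):
    """Draw params from the Dirichlet priors (hypothetical sketch)."""
    K, S = hypparams["n_states"], hypparams["n_syllables"]
    # One Dirichlet draw per (state, previous syllable) row -> shape (K, S, S)
    emissions = rng.dirichlet(np.full(S, hypparams["emission_beta"]), size=(K, S))
    # Sticky prior: row k gets extra concentration trans_kappa on entry k
    alpha = hypparams["trans_beta"] + hypparams["trans_kappa"] * np.eye(K)
    trans_probs = np.stack([rng.dirichlet(alpha[k]) for k in range(K)])
    return {"emissions": emissions, "trans_probs": trans_probs}
```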

state_moseq.hhmm_standard.simulate(seed, params, n_timesteps, n_sequences)[source]

Simulate data from the model.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • params (dict) – parameters dictionary

  • n_timesteps (Int) – number of timesteps to simulate

  • n_sequences (Int) – number of sessions to simulate

Returns:

  • states – simulated hidden states

  • syllables – simulated syllables

Return type:

tuple (states, syllables)
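The generative process can be sketched as follows. The sketch assumes that each hidden state selects a syllable transition matrix, so the next syllable is drawn from emissions[state, previous_syllable]. The uniform draws for the first state and first syllable are assumptions, and so is the name simulate_sketch.

```python
import numpy as np

def simulate_sketch(rng, params, n_timesteps, n_sequences):
    """Sample states and syllables from the model (hypothetical sketch)."""
    em, tp = params["emissions"], params["trans_probs"]
    K, S = em.shape[0], em.shape[1]
    states = np.zeros((n_sequences, n_timesteps), dtype=int)
    syllables = np.zeros((n_sequences, n_timesteps), dtype=int)
    for n in range(n_sequences):
        # Assumption: uniform distributions over the first state and first syllable
        states[n, 0] = rng.integers(K)
        syllables[n, 0] = rng.integers(S)
        for t in range(1, n_timesteps):
            states[n, t] = rng.choice(K, p=tp[states[n, t - 1]])
            syllables[n, t] = rng.choice(S, p=em[states[n, t], syllables[n, t - 1]])
    return states, syllables
```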

state_moseq.hhmm_standard.resample_params(seed, data, params, states, hypparams)[source]

Resample parameters from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • data (dict) – data dictionary

  • params (dict) – parameters dictionary

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • hypparams (dict) – hyperparameters dictionary

Returns:

params – parameters dictionary

Return type:

dict

state_moseq.hhmm_standard.resample_emission_params(seed, syllables, mask, states, n_states, n_syllables, emissions_beta)[source]

Resample emission parameters from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • syllables (Int[Array, 'n_sequences n_timesteps']) – syllable observations

  • mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • n_states (int) – number of hidden states

  • n_syllables (int) – number of syllables

  • emissions_beta (Float) – Dirichlet concentration parameter of the emission prior

Returns:

emissions – syllable transition probabilities for each state

Return type:

array of shape (n_states, n_syllables, n_syllables)
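The Dirichlet prior is conjugate to the categorical likelihood. Each emission row can therefore be resampled from Dirichlet(emissions_beta + counts), where counts tallies the syllable transitions i to j observed while in state k. The sketch makes two assumptions. A transition is counted only when both timesteps are unmasked, and it is attributed to the state at the later timestep. Dirichlet draws are built from normalized Gamma variates. The name resample_emissions_sketch is hypothetical.

```python
import numpy as np

def resample_emissions_sketch(rng, syllables, mask, states,
                              n_states, n_syllables, emissions_beta):
    """Conjugate Dirichlet update of per-state syllable transitions (sketch)."""
    # Assumption: count a transition only if both timesteps are unmasked,
    # and attribute it to the hidden state at the later timestep.
    valid = (mask[:, :-1] * mask[:, 1:]).astype(bool)
    k = states[:, 1:][valid]
    i = syllables[:, :-1][valid]
    j = syllables[:, 1:][valid]
    counts = np.zeros((n_states, n_syllables, n_syllables))
    np.add.at(counts, (k, i, j), 1)
    # Dirichlet(alpha) sample = Gamma(alpha) variates normalized per row
    g = rng.gamma(emissions_beta + counts)
    return g / g.sum(axis=-1, keepdims=True)
```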

state_moseq.hhmm_standard.resample_trans_probs(seed, mask, states, n_states, beta, kappa)[source]

Resample transition probabilities from their posterior distribution.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • mask (Int[Array, 'n_sequences n_timesteps']) – mask of valid observations

  • states (Int[Array, 'n_sequences n_timesteps']) – hidden states

  • n_states (int) – number of hidden states

  • beta (Float) – Dirichlet concentration added to every entry of each transition row

  • kappa (Float) – additional Dirichlet concentration on the self-transition (diagonal) entry of each row

Returns:

trans_probs – transition probabilities sampled from their posterior distribution

Return type:

array of shape (n_states, n_states)
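The same conjugate update applies to transitions. Assuming the sticky prior given under random_params, row k is resampled from Dirichlet(beta + kappa * e_k + counts[k]). Here counts[k, l] tallies the observed k to l state transitions, and e_k is the k-th unit vector. The sketch uses the same assumption as above about masked timesteps; resample_trans_probs_sketch is hypothetical.

```python
import numpy as np

def resample_trans_probs_sketch(rng, mask, states, n_states, beta, kappa):
    """Conjugate Dirichlet update of the transition matrix (sketch)."""
    # Assumption: count a transition only if both timesteps are unmasked
    valid = (mask[:, :-1] * mask[:, 1:]).astype(bool)
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (states[:, :-1][valid], states[:, 1:][valid]), 1)
    alpha = beta + kappa * np.eye(n_states) + counts
    # Dirichlet(alpha) sample per row via normalized Gamma variates
    g = rng.gamma(alpha)
    return g / g.sum(axis=1, keepdims=True)
```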