Utility Functions
Functions:
logits_to_probs | Vectorized version of logits_to_probs.
probs_to_logits | Vectorized version of probs_to_logits.
normal_inverse_gamma_posterior | Sample posterior mean and variance given normal-inverse gamma prior.
center_embedding | Generate an orthonormal matrix that embeds R^(n-1) into the space of 0-sum vectors in R^n.
lower_dim | Lower dimension in specified axis by projecting onto the space of 0-sum vectors.
raise_dim | Raise dimension in specified axis by embedding into the space of 0-sum vectors.
simulate_markov_chain | Simulate a state sequence from a Markov chain.
count_transitions | Count transitions between states.
compare_states | Compare high-level state sequences.
sample_hmc | Sample using Hamiltonian Monte Carlo.
sample_laplace | Sample using Laplace approximation.
symmetrize | Symmetrize a matrix by averaging it with its transpose.
psd_solve | Solve the linear system Ax = B, where A is a positive semi-definite matrix.
psd_inv | Compute the inverse of a positive semi-definite matrix using Cholesky decomposition.
cross_sequence_mutual_information | Compute cross-sequence mutual information.
lagged_mutual_information | Compute mutual information at a range of lags for real data and equivalent Markov chains.
save_hdf5 | Save a dict of pytrees to an hdf5 file.
load_hdf5 | Load a dict of pytrees from an hdf5 file.
unbatch | Invert batch().
batch | Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed-length segments.
get_durations | Get durations of high-level states.
sample_instances | Randomly sample instances of each state.
get_frequencies | Get frequencies for a batch of high-level state sequences.
get_adjusted_rand | Compute the adjusted Rand index between two sets of high-level state sequences.
- state_moseq.util.logits_to_probs(logits)
Vectorized version of logits_to_probs. Takes the same arguments as logits_to_probs, but with additional leading array axes over which logits_to_probs is mapped.
Original documentation:
Convert logits to probabilities.
- Parameters:
logits (Float[Array, 'n_categories-1'])
- Return type:
Float[Array, 'n_categories']
- state_moseq.util.probs_to_logits(probs, pseudo_count=1e-08)
Vectorized version of probs_to_logits. Takes the same arguments as probs_to_logits, but with additional leading array axes over which probs_to_logits is mapped.
Original documentation:
Convert probabilities to logits.
- Parameters:
probs (Float[Array, 'n_categories'])
pseudo_count (Float)
- Return type:
Float[Array, 'n_categories-1']
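The two conversions map between n-1 free logits and n probabilities. Below is a minimal NumPy sketch, not the library's source. It assumes that the last category has an implicit logit of 0; the library might instead use the zero-sum embedding from center_embedding. Broadcasting over `...` stands in for the vectorized mapping over leading axes.

```python
import numpy as np


def logits_to_probs(logits):
    """Map (..., n-1) logits to (..., n) probabilities.

    Assumption: the last category has an implicit logit of 0.
    """
    logits = np.asarray(logits, dtype=float)
    zeros = np.zeros(logits.shape[:-1] + (1,))
    full = np.concatenate([logits, zeros], axis=-1)
    # Subtract the max before exponentiating for numerical stability.
    full = full - full.max(axis=-1, keepdims=True)
    e = np.exp(full)
    return e / e.sum(axis=-1, keepdims=True)


def probs_to_logits(probs, pseudo_count=1e-8):
    """Map (..., n) probabilities to (..., n-1) logits relative to the last category."""
    log_p = np.log(np.asarray(probs, dtype=float) + pseudo_count)
    return log_p[..., :-1] - log_p[..., -1:]


# Leading axes are mapped over: (3, 2) logits give (3, 3) probabilities.
logits = np.array([[0.0, 0.0], [1.0, -1.0], [2.0, 0.5]])
probs = logits_to_probs(logits)
recovered = probs_to_logits(probs)
```

The round trip recovers the logits up to the small error introduced by the pseudo count.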
- state_moseq.util.normal_inverse_gamma_posterior(seed, mean, sigmasq, n, lambda_, alpha, beta)
Sample posterior mean and variance given normal-inverse gamma prior.
- Parameters:
seed (Float[Array, '2']) – random seed
mean (Float) – sample mean
sigmasq (Float) – sample variance
n (Int) – number of data points
lambda_ (Float) – strength of prior
alpha (Float) – inverse gamma shape parameter
beta (Float) – inverse gamma rate parameter
- Returns:
mu – posterior mean
sigma – posterior variance
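The sampler follows the standard conjugate normal-inverse-gamma update. Below is a minimal NumPy sketch under these assumptions:

- The prior mean is 0, because the signature has no prior-mean argument.
- `sigmasq` is the biased sample variance (sum of squared deviations divided by n).
- `seed` is an integer for NumPy's Generator rather than a JAX key.

The library may parameterize any of these differently.

```python
import numpy as np


def normal_inverse_gamma_posterior(seed, mean, sigmasq, n, lambda_, alpha, beta):
    """Draw (mu, sigmasq) from the normal-inverse-gamma posterior.

    Assumptions: prior mean 0, and `sigmasq` is the biased sample variance.
    """
    rng = np.random.default_rng(seed)
    # Conjugate updates of the four hyperparameters.
    lambda_n = lambda_ + n
    mu_n = n * mean / lambda_n
    alpha_n = alpha + n / 2
    beta_n = beta + 0.5 * n * sigmasq + 0.5 * lambda_ * n * mean**2 / lambda_n
    # If X ~ Gamma(shape=a, rate=b) then 1/X ~ InvGamma(a, b); NumPy takes scale = 1/rate.
    post_sigmasq = 1.0 / rng.gamma(alpha_n, 1.0 / beta_n)
    # Given the variance, the mean is normal with variance sigmasq / lambda_n.
    post_mu = rng.normal(mu_n, np.sqrt(post_sigmasq / lambda_n))
    return post_mu, post_sigmasq
```

With many data points, the draws concentrate near the sample mean and sample variance, as expected when the data overwhelm the prior.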
- state_moseq.util.center_embedding(n)
Generate an orthonormal matrix that embeds R^(n-1) into the space of 0-sum vectors in R^n.
- Parameters:
n (int)
- Return type:
Float[Array, 'n n-1']
- state_moseq.util.lower_dim(arr, axis=0)
Reduce the length of the specified axis from n to n-1 by projecting onto the space of 0-sum vectors.
- state_moseq.util.raise_dim(arr, axis=0)
Increase the length of the specified axis from n-1 to n by embedding into the space of 0-sum vectors.
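center_embedding, lower_dim and raise_dim fit together as follows: with an n×(n-1) matrix Q whose orthonormal columns span the zero-sum subspace, lowering multiplies by Q and raising multiplies by Qᵀ. Below is a NumPy sketch of one possible construction, QR orthonormalization of the centering matrix. The basis is not unique, so the library's matrix may differ by a rotation.

```python
import numpy as np


def center_embedding(n):
    """(n, n-1) matrix with orthonormal columns spanning the zero-sum subspace of R^n.

    One possible construction: orthonormalize the first n-1 columns of the
    centering matrix I - 11^T/n with a reduced QR decomposition.
    """
    centering = np.eye(n) - np.ones((n, n)) / n
    q, _ = np.linalg.qr(centering[:, : n - 1])
    return q


def lower_dim(arr, axis=0):
    """Project `axis` (length n) onto zero-sum coordinates (length n-1)."""
    arr = np.moveaxis(np.asarray(arr, dtype=float), axis, -1)
    q = center_embedding(arr.shape[-1])
    return np.moveaxis(arr @ q, -1, axis)


def raise_dim(arr, axis=0):
    """Embed `axis` (length n-1) back into zero-sum vectors of length n."""
    arr = np.moveaxis(np.asarray(arr, dtype=float), axis, -1)
    q = center_embedding(arr.shape[-1] + 1)
    return np.moveaxis(arr @ q.T, -1, axis)
```

Because QQᵀ is the orthogonal projector onto the zero-sum subspace, raise_dim(lower_dim(x)) returns x exactly when x already sums to zero along the axis.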
- state_moseq.util.simulate_markov_chain(seed, trans_probs, n_timesteps, init_probs=None)
Simulate a state sequence from a Markov chain.
- Parameters:
seed (Float[Array, '2']) – random seed
trans_probs (Float[Array, 'n_states n_states'] | Float[Array, 'n_timesteps n_states n_states']) – transition probabilities between states
n_timesteps (Int) – number of timesteps to simulate
init_probs (Float[Array, 'n_states'] | None) – initial state probabilities. If None, uniform distribution is used.
- Returns:
states – simulated state sequence
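Simulation draws an initial state, then repeatedly samples the next state from the row of the transition matrix indexed by the current state. Below is a minimal NumPy sketch with these assumptions:

- `trans_probs[i, j]` is P(next = j | current = i).
- `seed` is an integer.
- Only the time-homogeneous (n_states, n_states) case is handled.

```python
import numpy as np


def simulate_markov_chain(seed, trans_probs, n_timesteps, init_probs=None):
    """Simulate a state sequence from a time-homogeneous Markov chain (sketch).

    The time-varying (n_timesteps, n_states, n_states) case is not handled here.
    """
    rng = np.random.default_rng(seed)
    trans_probs = np.asarray(trans_probs, dtype=float)
    n_states = trans_probs.shape[0]
    if init_probs is None:
        init_probs = np.full(n_states, 1.0 / n_states)  # uniform, as documented
    states = np.empty(n_timesteps, dtype=int)
    states[0] = rng.choice(n_states, p=init_probs)
    for t in range(1, n_timesteps):
        # Assumed convention: row i is the next-state distribution given state i.
        states[t] = rng.choice(n_states, p=trans_probs[states[t - 1]])
    return states
```

A deterministic alternating chain started in state 0 yields 0, 1, 0, 1, ... regardless of the seed.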
- state_moseq.util.count_transitions(states, mask, n_states)
Count transitions between states.
- Parameters:
states (Int[Array, 'n_timesteps']) – discrete state sequence
mask (Int[Array, 'n_timesteps']) – mask of valid observations
n_states (int) – number of discrete states
- Returns:
trans_counts – transition counts
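Transition counting reduces to a 2D histogram of (state at t, state at t+1) pairs. Below is a NumPy sketch. It assumes that a transition is counted only when both of its timesteps are marked valid in `mask`; the library's masking rule may differ.

```python
import numpy as np


def count_transitions(states, mask, n_states):
    """Count i -> j transitions between consecutive timesteps.

    Assumption: a transition t -> t+1 is counted only if both timesteps are valid.
    """
    states = np.asarray(states)
    mask = np.asarray(mask)
    valid = (mask[:-1] > 0) & (mask[1:] > 0)
    trans_counts = np.zeros((n_states, n_states), dtype=int)
    # np.add.at accumulates correctly when the same (i, j) pair appears repeatedly.
    np.add.at(trans_counts, (states[:-1][valid], states[1:][valid]), 1)
    return trans_counts
```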
- state_moseq.util.compare_states(states1, states2, n_states=None)
Compare high-level state sequences.
- Parameters:
states1 (Int[Array, 'n_timesteps'] | Dict[str, Int[Array, 'n_timesteps']]) – first set of state sequences (can be an array or dictionary of sequences)
states2 (Int[Array, 'n_timesteps'] | Dict[str, Int[Array, 'n_timesteps']]) – second set of state sequences (can be an array or dictionary of sequences)
n_states (Int | None) – number of discrete states. If None, inferred from data.
- Returns:
confusion_matrix – confusion matrix
optimal_permutation – optimal permutation of first set of states to match second set
accuracy – proportion of timepoints with matching labels (after optimal permutation)
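Finding the best relabeling is an assignment problem on the confusion matrix. Below is a sketch that solves it with SciPy's Hungarian-algorithm solver, `scipy.optimize.linear_sum_assignment`. It is one standard way to do this, not necessarily what the library uses. Dictionary inputs are concatenated over their shared keys, which is also an assumption.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def compare_states(states1, states2, n_states=None):
    """Confusion matrix, best one-to-one label matching, and accuracy after matching."""
    if isinstance(states1, dict):
        # Assumption: compare dictionaries over their shared keys, in sorted order.
        keys = sorted(set(states1) & set(states2))
        states1 = np.concatenate([np.asarray(states1[k]) for k in keys])
        states2 = np.concatenate([np.asarray(states2[k]) for k in keys])
    states1, states2 = np.asarray(states1), np.asarray(states2)
    if n_states is None:
        n_states = int(max(states1.max(), states2.max())) + 1
    confusion = np.zeros((n_states, n_states), dtype=int)
    np.add.at(confusion, (states1, states2), 1)
    # Maximize the number of timepoints on which the matched labels agree.
    rows, cols = linear_sum_assignment(confusion, maximize=True)
    optimal_permutation = cols[np.argsort(rows)]  # state i in states1 -> label in states2
    accuracy = confusion[rows, cols].sum() / len(states1)
    return confusion, optimal_permutation, accuracy
```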
- state_moseq.util.sample_hmc(seed, log_prob_fn, init_params, num_leapfrog_steps=3, step_size=0.001, num_results=1, num_burnin_steps=100)
Sample using Hamiltonian Monte Carlo.
- Parameters:
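seed (Float[Array, '2']) – random seed
log_prob_fn (Callable) – log probability function
init_params (PyTree) – initial parameters
num_leapfrog_steps (Int) – number of leapfrog steps per proposal
step_size (Float) – leapfrog step size
num_results (Int) – number of samples to return
num_burnin_steps (Int) – number of burn-in steps discarded before samples are collected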
- Return type:
Tuple[PyTree, PyTree]
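Each HMC step resamples a momentum and integrates Hamiltonian dynamics with the leapfrog scheme. It then accepts or rejects the endpoint with a Metropolis correction. Below is a minimal NumPy sketch under these assumptions:

- Parameters are a flat vector rather than a pytree.
- The gradient comes from a hypothetical `grad_log_prob_fn` argument, standing in for automatic differentiation.
- `seed` is an integer.
- The return value (samples, acceptance flags) may not match the library's Tuple[PyTree, PyTree].

```python
import numpy as np


def sample_hmc(seed, log_prob_fn, grad_log_prob_fn, init_params,
               num_leapfrog_steps=3, step_size=0.001,
               num_results=1, num_burnin_steps=100):
    """Hamiltonian Monte Carlo on a flat parameter vector (illustrative sketch).

    `grad_log_prob_fn` is a hypothetical extra argument supplying the gradient.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(init_params, dtype=float)
    log_p = log_prob_fn(x)
    samples, accepted = [], []
    for i in range(num_burnin_steps + num_results):
        p0 = rng.standard_normal(x.shape)  # resample momentum
        x_new = x.copy()
        # Leapfrog: half momentum step, alternating full steps, final half step.
        p = p0 + 0.5 * step_size * grad_log_prob_fn(x_new)
        for step in range(num_leapfrog_steps):
            x_new = x_new + step_size * p
            if step < num_leapfrog_steps - 1:
                p = p + step_size * grad_log_prob_fn(x_new)
        p = p + 0.5 * step_size * grad_log_prob_fn(x_new)
        log_p_new = log_prob_fn(x_new)
        # Metropolis correction with the Hamiltonian H = -log p(x) + |p|^2 / 2.
        log_accept = (log_p_new - 0.5 * np.sum(p**2)) - (log_p - 0.5 * np.sum(p0**2))
        is_accepted = np.log(rng.uniform()) < log_accept
        if is_accepted:
            x, log_p = x_new, log_p_new
        if i >= num_burnin_steps:  # keep only post-burn-in draws
            samples.append(x.copy())
            accepted.append(bool(is_accepted))
    return np.stack(samples), np.array(accepted)
```

With a trajectory length (num_leapfrog_steps × step_size) near π/2, draws from a standard normal target are nearly independent. With a trajectory length near π, the chain would mostly flip sign, so this parameter matters.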
- state_moseq.util.sample_laplace(seed, log_prob_fn, init_params, gradient_descent_iters=200, gradient_descent_lr=0.01)
Sample using a Laplace approximation. Uses gradient descent to find the mode of the posterior.
- Parameters:
seed (Float[Array, '2']) – random seed
log_prob_fn (Callable) – log probability function
init_params (PyTree) – initial parameters
gradient_descent_iters (Int) – number of gradient descent iterations
gradient_descent_lr (Float) – gradient descent learning rate
- Returns:
params – sampled parameters
losses – loss history
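A Laplace approximation fits a Gaussian at the posterior mode. Its covariance is the inverse of the negative Hessian of the log probability at that point. Below is a NumPy sketch under these assumptions:

- Hypothetical `grad_log_prob_fn` and `hess_log_prob_fn` arguments replace automatic differentiation.
- The loss is the negative log probability.
- Parameters are a flat vector.

```python
import numpy as np


def sample_laplace(seed, log_prob_fn, grad_log_prob_fn, hess_log_prob_fn, init_params,
                   gradient_descent_iters=200, gradient_descent_lr=0.01):
    """Find the mode by gradient descent, then sample the fitted Gaussian (sketch).

    `grad_log_prob_fn` and `hess_log_prob_fn` are hypothetical extra arguments.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(init_params, dtype=float)
    losses = []
    for _ in range(gradient_descent_iters):
        losses.append(-log_prob_fn(x))  # loss = negative log probability
        x = x + gradient_descent_lr * grad_log_prob_fn(x)  # descend the loss
    # Covariance of the approximation = inverse of the negative Hessian at the mode.
    cov = np.linalg.inv(-hess_log_prob_fn(x))
    params = rng.multivariate_normal(x, cov)
    return params, np.array(losses)
```

For a Gaussian target, the approximation is exact, which makes it a convenient sanity check.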
- state_moseq.util.symmetrize(A)
Symmetrize a matrix by averaging it with its transpose.
- Parameters:
A (Float[Array, 'n n'])
- Return type:
Float[Array, 'n n']
- state_moseq.util.psd_solve(A, B, diagonal_boost=1e-06)
Solve the linear system Ax = B, where A is a positive semi-definite matrix.
- Parameters:
A (Float[Array, 'n n']) – positive semi-definite matrix
B (Float[Array, 'n m']) – right-hand side matrix
diagonal_boost (float) – value added to the diagonal to ensure positive definiteness
- Returns:
x – solution to the linear system
- state_moseq.util.psd_inv(A, diagonal_boost=1e-06)
Compute the inverse of a positive semi-definite matrix using Cholesky decomposition.
- Parameters:
A (Float[Array, 'n n']) – positive semi-definite matrix
diagonal_boost (float) – value added to the diagonal to ensure positive definiteness
- Returns:
Ainv – inverse of the matrix
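symmetrize, psd_solve and psd_inv combine naturally. A small diagonal boost makes a semi-definite matrix strictly positive definite, so a Cholesky factorization exists. Solving against the identity gives the inverse, and symmetrizing removes round-off asymmetry. Below is a sketch using SciPy's `cho_factor`/`cho_solve`. Symmetrizing the result in psd_inv is an assumption about the library's behavior.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def symmetrize(A):
    """Average a matrix with its transpose."""
    return (A + A.T) / 2


def psd_solve(A, B, diagonal_boost=1e-6):
    """Solve A x = B for PSD A via a Cholesky factorization of A + boost * I."""
    A = np.asarray(A, dtype=float)
    factor = cho_factor(A + diagonal_boost * np.eye(len(A)), lower=True)
    return cho_solve(factor, B)


def psd_inv(A, diagonal_boost=1e-6):
    """Invert PSD A by solving A X = I, then symmetrize (assumed) to remove round-off."""
    A = np.asarray(A, dtype=float)
    return symmetrize(psd_solve(A, np.eye(len(A)), diagonal_boost))
```

Reusing one Cholesky factorization for both solving and inverting is typically cheaper and more numerically stable than a general-purpose inverse for symmetric positive definite matrices.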
- state_moseq.util.cross_sequence_mutual_information(sequence1, sequence2, mask, n_categories, pseudo_count=1e-08)
Compute cross-sequence mutual information.
- Parameters:
sequence1 (Int[Array, 'n_timesteps']) – first sequence
sequence2 (Int[Array, 'n_timesteps']) – second sequence
mask (Bool[Array, 'n_timesteps']) – mask for valid timesteps
n_categories (int) – number of categories
pseudo_count (float) – pseudo count to add to probabilities
- Returns:
mi – mutual information
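Cross-sequence mutual information is computed from the joint histogram of the two sequences at valid timesteps. Below is a NumPy sketch. It assumes that the pseudo count is added to every cell of the normalized joint distribution before renormalizing, and that the result is in nats (natural log).

```python
import numpy as np


def cross_sequence_mutual_information(sequence1, sequence2, mask, n_categories,
                                      pseudo_count=1e-8):
    """Mutual information (nats) between two aligned categorical sequences (sketch)."""
    valid = np.asarray(mask, dtype=bool)
    s1 = np.asarray(sequence1)[valid]
    s2 = np.asarray(sequence2)[valid]
    joint = np.zeros((n_categories, n_categories))
    np.add.at(joint, (s1, s2), 1)
    # Assumption: smooth the joint distribution with the pseudo count, then renormalize.
    joint = joint / joint.sum() + pseudo_count
    joint = joint / joint.sum()
    p1 = joint.sum(axis=1, keepdims=True)  # marginal of sequence1
    p2 = joint.sum(axis=0, keepdims=True)  # marginal of sequence2
    return float(np.sum(joint * np.log(joint / (p1 * p2))))
```

Identical balanced binary sequences give log 2. Independent sequences give 0.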
- state_moseq.util.lagged_mutual_information(sequences, mask, lags, pseudo_count=1e-08)
Compute mutual information at a range of lags for real data and equivalent Markov chains.
- Parameters:
sequences (Int[Array, 'n_sequences n_timesteps']) – sequences from which to compute mutual information
mask (Bool[Array, 'n_sequences n_timesteps']) – mask indicating valid timesteps
lags (Int[Array, 'n_lags']) – array of temporal lags
pseudo_count (float) – pseudo count to use when computing mutual information
- Returns:
real_mi – mutual information for each sequence at each lag
markov_mi – mutual information for equivalent Markov chains at each lag
shuff_mi – mutual information across randomly paired sequences at each lag
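The real-versus-Markov comparison can be sketched for a single unmasked sequence:

- The real lagged MI comes from the empirical joint distribution of (x_t, x_{t+lag}).
- The "equivalent" first-order Markov chain is fitted from observed transitions.
- The Markov chain's lagged joint distribution is computed analytically as p0[i]·(T^lag)[i, j].

Several choices here are assumptions: the helper names `mi_from_joint` and `lagged_mi_sketch` are hypothetical, the marginal p0 is the empirical state frequency, and masking and the shuffled baseline (shuff_mi) are omitted. The library may instead simulate Markov chains.

```python
import numpy as np


def mi_from_joint(joint):
    """Mutual information (nats) of a nonnegative joint table, normalized here."""
    joint = joint / joint.sum()
    outer = joint.sum(axis=1, keepdims=True) * joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))


def lagged_mi_sketch(seq, lags, n_categories, pseudo_count=1e-8):
    """Hypothetical helper: (real_mi, markov_mi) for one sequence; every lag must be >= 1."""
    seq = np.asarray(seq)
    # Fit the equivalent first-order Markov chain from observed transitions.
    counts = np.full((n_categories, n_categories), pseudo_count)
    np.add.at(counts, (seq[:-1], seq[1:]), 1)
    trans = counts / counts.sum(axis=1, keepdims=True)
    p0 = np.bincount(seq, minlength=n_categories) / len(seq)
    real_mi, markov_mi = [], []
    for lag in lags:
        # Empirical joint distribution of (x_t, x_{t+lag}).
        joint = np.zeros((n_categories, n_categories))
        np.add.at(joint, (seq[:-lag], seq[lag:]), 1)
        real_mi.append(mi_from_joint(joint + pseudo_count))
        # Markov prediction: P(x_t = i, x_{t+lag} = j) = p0[i] * (T^lag)[i, j].
        markov_joint = p0[:, None] * np.linalg.matrix_power(trans, lag)
        markov_mi.append(mi_from_joint(markov_joint + pseudo_count))
    return np.array(real_mi), np.array(markov_mi)
```

For a strictly alternating sequence, both curves stay at log 2 at every lag, because the first-order chain captures the structure exactly. Gaps between the curves indicate dependencies beyond first order.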
- state_moseq.util.save_hdf5(filepath, save_dict, datapath=None, overwrite_results=False)
Save a dict of pytrees to an hdf5 file. The leaves of the pytrees must be numpy arrays, scalars, or strings.
- Parameters:
filepath (str) – Path of the hdf5 file to create.
save_dict (Dict[str, PyTree]) – Dictionary where the values are pytrees.
datapath (str | None) – Path within hdf5 file to save the data. If None, data are saved at the root.
overwrite_results (bool)
- Return type:
None
- state_moseq.util.load_hdf5(filepath, datapath=None)
Load a dict of pytrees from an hdf5 file.
- Parameters:
filepath (str) – Path of the hdf5 file to load.
datapath (str | None) – Path within hdf5 file to load data from. If None, loads from the root.
- Returns:
save_dict – Dictionary where the values are pytrees.
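One way to store a dict of pytrees in HDF5 is to flatten nested dictionaries into "/"-separated paths, which map directly onto HDF5 groups and datasets. Below is a sketch of that idea:

- All helper names are hypothetical, and only nested dicts (not lists or tuples) are handled.
- The h5py writer always replaces existing datasets, unlike the documented overwrite_results flag.
- h5py is imported lazily, so the path helpers run without it.
- Strings may be read back as bytes.

```python
def flatten_to_paths(tree, prefix=""):
    """Hypothetical helper: flatten nested dicts into {"a/b/c": leaf}."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_to_paths(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_from_paths(flat):
    """Hypothetical helper: inverse of flatten_to_paths."""
    tree = {}
    for path, value in flat.items():
        node = tree
        *parents, leaf = path.split("/")
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return tree


def save_sketch(filepath, save_dict, datapath=None):
    """Hypothetical writer: one dataset per leaf; h5py creates intermediate groups."""
    import h5py  # lazy import so the helpers above work without h5py

    with h5py.File(filepath, "a") as f:
        for path, leaf in flatten_to_paths(save_dict, datapath or "").items():
            if path in f:
                del f[path]  # this sketch always overwrites existing data
            f.create_dataset(path, data=leaf)


def load_sketch(filepath, datapath=None):
    """Hypothetical reader: collect every dataset under `datapath` into nested dicts."""
    import h5py

    flat = {}
    with h5py.File(filepath, "r") as f:
        root = f[datapath] if datapath else f

        def visit(name, obj):
            if isinstance(obj, h5py.Dataset):
                flat[name] = obj[()]

        root.visititems(visit)
    return unflatten_from_paths(flat)
```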
- state_moseq.util.unbatch(data, keys, bounds)
Invert state_moseq.util.batch().
- Parameters:
data (Array) – Stack of segmented time-series, shape (n_segs, seg_length, …).
keys (list[str] | Array) – Names of the time-series that the segments came from, one per segment.
bounds (Int[Array, 'n_segs 2']) – Start and end indices for each segment.
- Returns:
data_dict – Dictionary mapping names to reconstructed time-series.
- state_moseq.util.batch(data_dict, keys=None, seg_length=None, seg_overlap=30)
Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed-length segments. The data are padded so that the stacked array is not ragged; the padding repeats the last frame of each time-series until the end of the segment.
- Parameters:
data_dict (Dict[str, Array]) – Dictionary of time-series, each of shape (T, …).
keys (list[str] | None) – Optional list of keys to control order and inclusion of time-series.
seg_length (int | None) – Length of each segment. Defaults to max sequence length.
seg_overlap (int) – Overlap between segments in frames.
- Returns:
data – Stacked data array, shape (N, seg_length, …).
mask – Binary mask for valid data (1 = valid, 0 = padding), shape (N, seg_length).
metadata – Tuple (keys, bounds), identifying sources and segment positions.
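The batch/unbatch round trip works like this: cut each series into windows, pad the final window by repeating the last frame, and record (key, start, end) so that unbatch can write every segment back into place. Below is a NumPy sketch. It assumes that consecutive segments start `seg_length - seg_overlap` frames apart, that keys default to sorted order, and that frames covered by two segments take the later segment's values. The library's segment layout may differ.

```python
import numpy as np


def batch(data_dict, keys=None, seg_length=None, seg_overlap=30):
    """Stack variable-length time-series into fixed-length, end-padded segments (sketch).

    Assumption: consecutive segments start `seg_length - seg_overlap` frames apart.
    """
    keys = sorted(data_dict) if keys is None else list(keys)
    if seg_length is None:
        # One segment per time-series, padded to the longest one.
        seg_length = max(len(data_dict[k]) for k in keys)
        seg_overlap = 0
    step = seg_length - seg_overlap
    if step <= 0:
        raise ValueError("seg_overlap must be smaller than seg_length")
    segs, masks, seg_keys, bounds = [], [], [], []
    for key in keys:
        arr = np.asarray(data_dict[key])
        T = len(arr)
        start = 0
        while True:
            end = min(start + seg_length, T)
            pad = seg_length - (end - start)
            # Padding only occurs in the final segment and repeats the series' last frame.
            padding = np.repeat(arr[end - 1 : end], pad, axis=0)
            segs.append(np.concatenate([arr[start:end], padding], axis=0))
            masks.append(np.r_[np.ones(end - start, dtype=int), np.zeros(pad, dtype=int)])
            seg_keys.append(key)
            bounds.append((start, end))
            if end == T:
                break
            start += step
    return np.stack(segs), np.stack(masks), (seg_keys, np.array(bounds))


def unbatch(data, keys, bounds):
    """Reassemble time-series from segments; overlapping frames take the later segment."""
    lengths = {}
    for key, (start, end) in zip(keys, bounds):
        lengths[key] = max(lengths.get(key, 0), int(end))
    out = {k: np.zeros((T,) + data.shape[2:], dtype=data.dtype) for k, T in lengths.items()}
    for seg, key, (start, end) in zip(data, keys, bounds):
        out[key][start:end] = seg[: end - start]
    return out
```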
- state_moseq.util.get_durations(states_dict)
Get durations of high-level states.
- Parameters:
states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of high-level state sequences.
- Returns:
durations – Times between high-level state transitions (across all sequences).
Examples
>>> import numpy as np
>>> from state_moseq.util import get_durations
>>> states_dict = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1]),
... }
>>> get_durations(states_dict)
array([2, 3, 1, 3, 1])
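The durations in the example are run lengths: each maximal block of repeated states contributes its length, including the final block of each sequence. Below is a NumPy run-length-encoding sketch that reproduces the documented output. It is an illustration, not the library's source.

```python
import numpy as np


def get_durations(states_dict):
    """Run lengths of every state sequence, concatenated in dictionary order (sketch)."""
    durations = []
    for states in states_dict.values():
        states = np.asarray(states)
        # Indices where a new run begins: position 0 plus every change point.
        starts = np.flatnonzero(np.r_[True, states[1:] != states[:-1]])
        ends = np.r_[starts[1:], len(states)]
        durations.append(ends - starts)
    return np.concatenate(durations)
```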
- state_moseq.util.sample_instances(states_dict, num_instances)
Randomly sample instances of each state.
- Parameters:
states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of state sequences.
num_instances (int) – Number of instances per state.
- Returns:
sampled_instances – Dictionary mapping state index to instances.
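The documentation does not define an "instance". Below is a NumPy sketch that assumes an instance is one contiguous run of a state, recorded as (name, start, end). It samples up to `num_instances` runs per state without replacement. The function name `sample_instances_sketch` and its `seed` argument are hypothetical.

```python
import numpy as np


def sample_instances_sketch(states_dict, num_instances, seed=0):
    """Hypothetical: sample up to `num_instances` runs of each state as (name, start, end).

    Assumptions: an instance is one contiguous run, sampled without replacement.
    """
    rng = np.random.default_rng(seed)
    runs = {}
    for name, states in states_dict.items():
        states = np.asarray(states)
        starts = np.flatnonzero(np.r_[True, states[1:] != states[:-1]])
        ends = np.r_[starts[1:], len(states)]
        for start, end in zip(starts, ends):
            runs.setdefault(int(states[start]), []).append((name, int(start), int(end)))
    sampled_instances = {}
    for state, state_runs in runs.items():
        n = min(num_instances, len(state_runs))
        idx = rng.choice(len(state_runs), size=n, replace=False)
        sampled_instances[state] = [state_runs[i] for i in idx]
    return sampled_instances
```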
- state_moseq.util.get_frequencies(states_dict, num_states=None, runlength=False)
Get frequencies for a batch of high-level state sequences.
- Parameters:
states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of high-level state sequences.
num_states (int | None) – Total number of states. If None, inferred from data.
runlength (bool) – If True, count only the first timepoint of each run of a state.
- Returns:
frequencies – Frequency of each state across all state sequences
Examples
>>> import numpy as np
>>> from state_moseq.util import get_frequencies
>>> states_dict = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1]),
... }
>>> get_frequencies(states_dict, runlength=True)
array([0.2, 0.4, 0.2, 0.2])
>>> get_frequencies(states_dict, runlength=False)
array([0.3, 0.3, 0.3, 0.1])
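The two modes in the example differ only in what is counted. With runlength=False, every timepoint is counted. With runlength=True, only the first timepoint of each run is counted, so five runs give frequencies in multiples of 0.2. Below is a NumPy sketch that reproduces both documented outputs. It is an illustration, not the library's source.

```python
import numpy as np


def get_frequencies(states_dict, num_states=None, runlength=False):
    """Fraction of timepoints (or of runs, if runlength=True) in each state (sketch)."""
    all_states = []
    for states in states_dict.values():
        states = np.asarray(states)
        if runlength:
            # Keep only the first timepoint of each run.
            states = states[np.r_[True, states[1:] != states[:-1]]]
        all_states.append(states)
    all_states = np.concatenate(all_states)
    if num_states is None:
        num_states = int(all_states.max()) + 1
    counts = np.bincount(all_states, minlength=num_states)
    return counts / counts.sum()
```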
- state_moseq.util.get_adjusted_rand(states_dict1, states_dict2, downsample=10)
Compute the adjusted Rand index between two sets of high-level state sequences.
- Parameters:
states_dict1 (Dict[str, Int[Array, 'n_timesteps']]) – First dictionary of high-level state sequences.
states_dict2 (Dict[str, Int[Array, 'n_timesteps']]) – Second dictionary of high-level state sequences.
downsample (int) – Downsampling factor to reduce the length of the sequences.
- Returns:
adjusted_rand_index – Adjusted Rand index between the two sets of state sequences.
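The adjusted Rand index compares two labelings through their contingency table. It counts agreeing pairs, subtracts the value expected by chance, and rescales so that identical partitions score 1. Below is a NumPy sketch under these assumptions:

- Downsampling keeps every `downsample`-th frame.
- Dictionaries are concatenated over their shared keys.
- The names `adjusted_rand_index`, `comb2` and `get_adjusted_rand_sketch` are hypothetical.

```python
import numpy as np


def comb2(x):
    """Number of unordered pairs, x choose 2 (elementwise)."""
    return x * (x - 1) / 2


def adjusted_rand_index(labels1, labels2):
    """Hypothetical helper: adjusted Rand index from the contingency table."""
    _, a = np.unique(labels1, return_inverse=True)
    _, b = np.unique(labels2, return_inverse=True)
    table = np.zeros((a.max() + 1, b.max() + 1))
    np.add.at(table, (a, b), 1)
    sum_ij = comb2(table).sum()
    sum_a = comb2(table.sum(axis=1)).sum()
    sum_b = comb2(table.sum(axis=0)).sum()
    expected = sum_a * sum_b / comb2(len(a))  # chance-level pair agreement
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # degenerate case, e.g. both labelings have one cluster
        return 1.0
    return float((sum_ij - expected) / (max_index - expected))


def get_adjusted_rand_sketch(states_dict1, states_dict2, downsample=10):
    """Hypothetical: concatenate shared keys, keep every `downsample`-th frame, then score."""
    keys = sorted(set(states_dict1) & set(states_dict2))
    s1 = np.concatenate([np.asarray(states_dict1[k])[::downsample] for k in keys])
    s2 = np.concatenate([np.asarray(states_dict2[k])[::downsample] for k in keys])
    return adjusted_rand_index(s1, s2)
```

The index ignores label names, so two segmentations that differ only by a relabeling score 1.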