Utility Functions

Functions:

logits_to_probs(logits)

Vectorized version of logits_to_probs.

probs_to_logits(probs[, pseudo_count])

Vectorized version of probs_to_logits.

normal_inverse_gamma_posterior(seed, mean, ...)

Sample posterior mean and variance given normal-inverse gamma prior.

center_embedding(n)

Generate an orthonormal matrix that embeds R^(n-1) into the space of 0-sum vectors in R^n.

lower_dim(arr[, axis])

Lower the dimension of the specified axis by projecting onto the space of 0-sum vectors.

raise_dim(arr[, axis])

Raise the dimension of the specified axis by embedding into the space of 0-sum vectors.

simulate_markov_chain(seed, trans_probs, ...)

Simulate a state sequence from a Markov chain.

count_transitions(states, mask, n_states)

Count transitions between states.

compare_states(states1, states2[, n_states])

Compare high-level state sequences.

sample_hmc(seed, log_prob_fn, init_params[, ...])

Sample using Hamiltonian Monte Carlo.

sample_laplace(seed, log_prob_fn, init_params)

Sample using Laplace approximation.

symmetrize(A)

Symmetrize a matrix by averaging it with its transpose.

psd_solve(A, B[, diagonal_boost])

Solve the linear system Ax = B, where A is a positive semi-definite matrix.

psd_inv(A[, diagonal_boost])

Compute the inverse of a positive semi-definite matrix using Cholesky decomposition.

cross_sequence_mutual_information(sequence1, ...)

Compute cross-sequence mutual information.

lagged_mutual_information(sequences, mask, lags)

Compute mutual information at a range of lags for real data and equivalent Markov chains.

save_hdf5(filepath, save_dict[, datapath, ...])

Save a dict of pytrees to an hdf5 file.

load_hdf5(filepath[, datapath])

Load a dict of pytrees from an hdf5 file.

unbatch(data, keys, bounds)

Invert state_moseq.util.batch()

batch(data_dict[, keys, seg_length, seg_overlap])

Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments.

get_durations(states_dict)

Get durations of high-level states.

sample_instances(states_dict, num_instances)

Randomly sample instances of each state.

get_frequencies(states_dict[, num_states, ...])

Get frequencies for a batch of high-level state sequences.

get_adjusted_rand(states_dict1, states_dict2)

Compute the adjusted Rand index between two sets of high-level state sequences.

state_moseq.util.logits_to_probs(logits)

Vectorized version of logits_to_probs. Takes similar arguments to logits_to_probs, but with additional array axes over which logits_to_probs is mapped.

Original documentation:

Convert logits to probabilities.

Parameters:

logits (Float[Array, 'n_categories-1'])

Return type:

Float[Array, 'n_categories']

state_moseq.util.probs_to_logits(probs, pseudo_count=1e-08)

Vectorized version of probs_to_logits. Takes similar arguments to probs_to_logits, but with additional array axes over which probs_to_logits is mapped.

Original documentation:

Convert probabilities to logits.

Parameters:
  • probs (Float[Array, 'n_categories'])

  • pseudo_count (Float)

Return type:

Float[Array, 'n_categories-1']
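A minimal NumPy sketch of one common way to map n_categories-1 logits to n_categories probabilities: append a fixed zero logit for the last (reference) category and apply a softmax; the inverse takes log-odds against that category. This parameterization is an assumption, not necessarily the one state_moseq uses (it may instead rely on the zero-sum embedding from center_embedding), and the function names are hypothetical. Operating on the last axis lets NumPy handle any leading batch axes.

```python
import numpy as np

def logits_to_probs_sketch(logits):
    """Map (..., n-1) logits to (..., n) probabilities (zero reference logit assumed)."""
    logits = np.asarray(logits, dtype=float)
    # Append a zero logit for the reference (last) category.
    zeros = np.zeros(logits.shape[:-1] + (1,))
    full = np.concatenate([logits, zeros], axis=-1)
    # Numerically stable softmax over the last axis.
    full = full - full.max(axis=-1, keepdims=True)
    expd = np.exp(full)
    return expd / expd.sum(axis=-1, keepdims=True)

def probs_to_logits_sketch(probs, pseudo_count=1e-8):
    """Inverse map: log-odds of each category against the last one."""
    log_p = np.log(np.asarray(probs, dtype=float) + pseudo_count)
    return log_p[..., :-1] - log_p[..., -1:]
```

The pseudo_count keeps the logarithm finite when a probability is exactly zero, at the cost of a tiny bias in the recovered logits.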

state_moseq.util.normal_inverse_gamma_posterior(seed, mean, sigmasq, n, lambda_, alpha, beta)

Sample posterior mean and variance given normal-inverse gamma prior.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • mean (Float) – sample mean

  • sigmasq (Float) – sample variance

  • n (Int) – number of data points

  • lambda_ (Float) – strength of prior

  • alpha (Float) – inverse gamma shape parameter

  • beta (Float) – inverse gamma rate parameter

Returns:

  • mu – posterior mean

  • sigma – posterior variance
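The conjugate update behind this function can be sketched in NumPy as follows. Because the signature has no prior-mean argument, the sketch assumes a prior mean of 0, i.e. mu | sigma^2 ~ N(0, sigma^2 / lambda_) and sigma^2 ~ InvGamma(alpha, beta), and it assumes sigmasq is the biased sample variance (sum of squared deviations divided by n). It uses a NumPy Generator in place of a JAX key; the function names are hypothetical.

```python
import numpy as np

def nig_posterior_params(mean, sigmasq, n, lambda_, alpha, beta):
    """Posterior hyperparameters, assuming a prior mean of 0."""
    lambda_n = lambda_ + n
    mu_n = n * mean / lambda_n  # shrink the sample mean toward 0
    alpha_n = alpha + n / 2
    # n * sigmasq is the sum of squared deviations (biased sigmasq assumed).
    beta_n = beta + 0.5 * n * sigmasq + 0.5 * lambda_ * n * mean**2 / lambda_n
    return mu_n, lambda_n, alpha_n, beta_n

def sample_nig_posterior(rng, mean, sigmasq, n, lambda_, alpha, beta):
    """Draw (mu, sigmasq_post) from the normal-inverse-gamma posterior."""
    mu_n, lambda_n, alpha_n, beta_n = nig_posterior_params(
        mean, sigmasq, n, lambda_, alpha, beta)
    # InvGamma(alpha_n, beta_n) draw = 1 / Gamma(shape=alpha_n, rate=beta_n).
    sigmasq_post = 1.0 / rng.gamma(shape=alpha_n, scale=1.0 / beta_n)
    # mu | sigmasq_post ~ N(mu_n, sigmasq_post / lambda_n).
    mu = rng.normal(mu_n, np.sqrt(sigmasq_post / lambda_n))
    return mu, sigmasq_post
```

With many data points the prior is swamped, so draws concentrate near the sample mean and sample variance.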

state_moseq.util.center_embedding(n)

Generate an orthonormal matrix that embeds R^(n-1) into the space of 0-sum vectors in R^n.

Parameters:

n (int)

Return type:

Float[Array, 'n n-1']

state_moseq.util.lower_dim(arr, axis=0)

Lower the dimension of the specified axis by projecting onto the space of 0-sum vectors.

state_moseq.util.raise_dim(arr, axis=0)

Raise the dimension of the specified axis by embedding into the space of 0-sum vectors.
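One way to build such an embedding, and the lower_dim/raise_dim pair on top of it, is sketched below in NumPy. It takes a reduced QR decomposition of the n × (n-1) matrix whose columns are e_i - e_n; those columns span the zero-sum subspace, so Q has orthonormal columns spanning it too. The construction is an assumption: any orthonormal basis of the subspace works, and state_moseq may use a different one, so the lowered coordinates can differ numerically. Function names are hypothetical.

```python
import numpy as np

def center_embedding_sketch(n):
    """Return an (n, n-1) matrix with orthonormal, zero-sum columns."""
    # Columns e_i - e_n (i < n) form a basis of the zero-sum subspace of R^n.
    basis = np.vstack([np.eye(n - 1), -np.ones((1, n - 1))])
    q, _ = np.linalg.qr(basis)  # reduced QR: q has shape (n, n-1)
    return q

def lower_dim_sketch(arr, axis=0):
    """Project an axis of length n onto n-1 zero-sum coordinates."""
    arr = np.moveaxis(np.asarray(arr, dtype=float), axis, -1)
    q = center_embedding_sketch(arr.shape[-1])
    return np.moveaxis(arr @ q, -1, axis)

def raise_dim_sketch(arr, axis=0):
    """Embed an axis of length n-1 back into zero-sum vectors of length n."""
    arr = np.moveaxis(np.asarray(arr, dtype=float), axis, -1)
    q = center_embedding_sketch(arr.shape[-1] + 1)
    return np.moveaxis(arr @ q.T, -1, axis)
```

Because Q Q^T is the orthogonal projection onto zero-sum vectors, raising a lowered array returns the input with its mean along that axis subtracted.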

state_moseq.util.simulate_markov_chain(seed, trans_probs, n_timesteps, init_probs=None)

Simulate a state sequence from a Markov chain.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • trans_probs (Float[Array, 'n_states n_states'] | Float[Array, 'n_timesteps n_states n_states']) – transition probabilities between states

  • n_timesteps (Int) – number of timesteps to simulate

  • init_probs (Float[Array, 'n_states'] | None) – initial state probabilities. If None, uniform distribution is used.

Returns:

states – simulated state sequence
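A minimal NumPy sketch of the time-homogeneous case, assuming rows of trans_probs index the current state. It uses a NumPy Generator in place of a JAX key and does not cover the time-varying (n_timesteps, n_states, n_states) form; the name is hypothetical.

```python
import numpy as np

def simulate_markov_chain_sketch(rng, trans_probs, n_timesteps, init_probs=None):
    """Draw a state sequence from a time-homogeneous Markov chain."""
    trans_probs = np.asarray(trans_probs, dtype=float)
    n_states = trans_probs.shape[0]
    if init_probs is None:
        init_probs = np.full(n_states, 1.0 / n_states)  # uniform start, as documented
    states = np.empty(n_timesteps, dtype=int)
    states[0] = rng.choice(n_states, p=init_probs)
    for t in range(1, n_timesteps):
        # The row of the current state is the distribution of the next state.
        states[t] = rng.choice(n_states, p=trans_probs[states[t - 1]])
    return states
```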

state_moseq.util.count_transitions(states, mask, n_states)

Count transitions between states.

Parameters:
  • states (Int[Array, 'n_timesteps']) – discrete state sequence

  • mask (Int[Array, 'n_timesteps']) – mask of valid observations

  • n_states (int) – number of discrete states

Returns:

trans_counts – transition counts
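A NumPy sketch of the counting, assuming a transition from t to t+1 is counted only when both timesteps are valid under mask, and that rows of the result index the source state. The name is hypothetical.

```python
import numpy as np

def count_transitions_sketch(states, mask, n_states):
    """Count state-to-state transitions between consecutive valid timesteps."""
    states = np.asarray(states)
    mask = np.asarray(mask).astype(bool)
    # A transition t -> t+1 counts only if both endpoints are valid.
    valid = mask[:-1] & mask[1:]
    trans_counts = np.zeros((n_states, n_states), dtype=int)
    # np.add.at accumulates correctly when the same (i, j) pair repeats.
    np.add.at(trans_counts, (states[:-1][valid], states[1:][valid]), 1)
    return trans_counts
```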

state_moseq.util.compare_states(states1, states2, n_states=None)

Compare high-level state sequences.

Parameters:
  • states1 (Int[Array, 'n_timesteps'] | Dict[str, Int[Array, 'n_timesteps']]) – first set of state sequences (can be an array or dictionary of sequences)

  • states2 (Int[Array, 'n_timesteps'] | Dict[str, Int[Array, 'n_timesteps']]) – second set of state sequences (can be an array or dictionary of sequences)

  • n_states (Int | None) – number of discrete states. If None, inferred from data.

Returns:

  • confusion_matrix – confusion matrix

  • optimal_permutation – optimal permutation of the first set of states to match the second set

  • accuracy – proportion of timepoints with matching labels (after optimal permutation)
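A sketch for the single-array case, assuming rows of the confusion matrix index states1 and columns index states2. The optimal permutation is found here by exhaustive search over permutations, which is only practical for a handful of states; scipy.optimize.linear_sum_assignment solves the same assignment problem efficiently. The name and the permutation convention are assumptions.

```python
import itertools
import numpy as np

def compare_states_sketch(states1, states2, n_states=None):
    states1, states2 = np.asarray(states1), np.asarray(states2)
    if n_states is None:
        n_states = int(max(states1.max(), states2.max())) + 1
    # confusion_matrix[i, j] = timepoints labeled i in states1 and j in states2.
    confusion_matrix = np.zeros((n_states, n_states), dtype=int)
    np.add.at(confusion_matrix, (states1, states2), 1)
    # optimal_permutation[i] = label in states2 matched to label i in states1.
    rows = np.arange(n_states)
    best = max(itertools.permutations(range(n_states)),
               key=lambda perm: confusion_matrix[rows, list(perm)].sum())
    optimal_permutation = np.array(best)
    matched = confusion_matrix[rows, optimal_permutation].sum()
    accuracy = matched / len(states1)
    return confusion_matrix, optimal_permutation, accuracy
```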

state_moseq.util.sample_hmc(seed, log_prob_fn, init_params, num_leapfrog_steps=3, step_size=0.001, num_results=1, num_burnin_steps=100)

Sample using Hamiltonian Monte Carlo.

Parameters:
  • seed (Float[Array, '2'])

  • log_prob_fn (Callable)

  • init_params (PyTree)

  • num_leapfrog_steps (Int)

  • step_size (Float)

  • num_results (Int)

  • num_burnin_steps (Int)

Return type:

Tuple[PyTree, PyTree]
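To illustrate the technique itself, here is a NumPy sketch of Hamiltonian Monte Carlo for a flat parameter array: leapfrog integration of position and momentum, a Metropolis accept/reject step on the change in total energy, and burn-in draws discarded. NumPy has no autodiff, so the sketch takes a hypothetical grad_fn argument that state_moseq's function does not have; state_moseq works on JAX pytrees, and its second return value is not necessarily the acceptance flags returned here.

```python
import numpy as np

def sample_hmc_sketch(rng, log_prob_fn, grad_fn, init_params, num_leapfrog_steps=3,
                      step_size=1e-3, num_results=1, num_burnin_steps=100):
    """HMC for a flat parameter array; returns (samples, accepted flags)."""
    x = np.asarray(init_params, dtype=float)
    samples, accepted = [], []
    for i in range(num_burnin_steps + num_results):
        p0 = rng.standard_normal(x.shape)  # resample momentum each iteration
        x_new, p = x.copy(), p0.copy()
        # Leapfrog: half momentum step, alternating full steps, final half step.
        p = p + 0.5 * step_size * grad_fn(x_new)
        for step in range(num_leapfrog_steps):
            x_new = x_new + step_size * p
            if step < num_leapfrog_steps - 1:
                p = p + step_size * grad_fn(x_new)
        p = p + 0.5 * step_size * grad_fn(x_new)
        # Metropolis correction using the change in total energy.
        log_accept = (log_prob_fn(x_new) - 0.5 * np.sum(p**2)) - (
            log_prob_fn(x) - 0.5 * np.sum(p0**2))
        accept = np.log(rng.uniform()) < log_accept
        if accept:
            x = x_new
        if i >= num_burnin_steps:  # keep only post-burn-in draws
            samples.append(x.copy())
            accepted.append(accept)
    return np.stack(samples), np.array(accepted)
```

The default step_size of 0.001 with 3 leapfrog steps moves very little per iteration; on a well-scaled target, a larger step and more leapfrog steps usually mix much faster.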

state_moseq.util.sample_laplace(seed, log_prob_fn, init_params, gradient_descent_iters=200, gradient_descent_lr=0.01)

Sample using a Laplace approximation. Uses gradient descent to find the mode of the posterior.

Parameters:
  • seed (Float[Array, '2']) – random seed

  • log_prob_fn (Callable) – log probability function

  • init_params (PyTree) – initial parameters

  • gradient_descent_iters (Int) – number of gradient descent iterations

  • gradient_descent_lr (Float) – gradient descent learning rate

Returns:

  • params – sampled parameters

  • losses – loss history
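A NumPy sketch of the Laplace approach for a flat parameter vector: gradient steps on the loss -log_prob_fn to reach the mode, then one draw from a Gaussian centred at the mode whose covariance is the inverse of the negative Hessian there. Because NumPy has no autodiff, the hypothetical grad_fn and hess_fn arguments stand in for derivatives that state_moseq presumably obtains from JAX; treating the loss history as negative log probabilities is also an assumption.

```python
import numpy as np

def sample_laplace_sketch(rng, log_prob_fn, grad_fn, hess_fn, init_params,
                          gradient_descent_iters=200, gradient_descent_lr=0.01):
    x = np.asarray(init_params, dtype=float)
    losses = []
    for _ in range(gradient_descent_iters):
        # Gradient descent on loss = -log_prob, i.e. ascent on log_prob.
        x = x + gradient_descent_lr * grad_fn(x)
        losses.append(-log_prob_fn(x))
    # Laplace approximation: N(mode, (-Hessian at mode)^{-1}).
    cov = np.linalg.inv(-hess_fn(x))
    chol = np.linalg.cholesky(cov)
    params = x + chol @ rng.standard_normal(x.shape)
    return params, np.array(losses)
```

The draw is only as good as the mode estimate, so checking that the loss history has flattened is a cheap sanity check.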

state_moseq.util.symmetrize(A)

Symmetrize a matrix by averaging it with its transpose.

Parameters:

A (Float[Array, 'n n'])

Return type:

Float[Array, 'n n']

state_moseq.util.psd_solve(A, B, diagonal_boost=1e-06)

Solve the linear system Ax = B, where A is a positive semi-definite matrix.

Parameters:
  • A (Float[Array, 'n n']) – positive semi-definite matrix

  • B (Float[Array, 'n m']) – right-hand side matrix

  • diagonal_boost (float) – boost added to the diagonal to ensure positive definiteness

Returns:

x – solution to the linear system

state_moseq.util.psd_inv(A, diagonal_boost=1e-06)

Compute the inverse of a positive semi-definite matrix using Cholesky decomposition.

Parameters:
  • A (Float[Array, 'n n']) – positive semi-definite matrix

  • diagonal_boost (float) – boost added to the diagonal to ensure positive definiteness

Returns:

Ainv – inverse of the matrix
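A NumPy sketch of how symmetrize, psd_solve, and psd_inv can fit together: symmetrize A, add diagonal_boost to its diagonal, factor with a Cholesky decomposition, and solve with the factor. For brevity it calls np.linalg.solve on the triangular factors (scipy.linalg.cho_solve exploits the triangular structure directly); the exact order of operations in state_moseq is an assumption and the names are hypothetical.

```python
import numpy as np

def symmetrize_sketch(A):
    """Average a matrix with its transpose."""
    return 0.5 * (A + A.T)

def psd_solve_sketch(A, B, diagonal_boost=1e-6):
    """Solve A x = B for (near-)PSD A via a Cholesky factorization."""
    A = symmetrize_sketch(np.asarray(A, dtype=float))
    A = A + diagonal_boost * np.eye(A.shape[0])  # nudge toward positive definite
    L = np.linalg.cholesky(A)  # A = L L^T
    # Forward solve L y = B, then back solve L^T x = y.
    y = np.linalg.solve(L, B)
    return np.linalg.solve(L.T, y)

def psd_inv_sketch(A, diagonal_boost=1e-6):
    """Inverse of a (near-)PSD matrix: solve A X = I."""
    return psd_solve_sketch(A, np.eye(np.asarray(A).shape[0]), diagonal_boost)
```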

state_moseq.util.cross_sequence_mutual_information(sequence1, sequence2, mask, n_categories, pseudo_count=1e-08)

Compute cross-sequence mutual information.

Parameters:
  • sequence1 (Int[Array, 'n_timesteps']) – first sequence

  • sequence2 (Int[Array, 'n_timesteps']) – second sequence

  • mask (Bool[Array, 'n_timesteps']) – mask for valid timesteps

  • n_categories (int) – number of categories

  • pseudo_count (float) – pseudo count to add to probabilities

Returns:

mi – mutual information
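A NumPy sketch of the plug-in estimate: build the joint histogram of co-occurring labels at valid timesteps, add pseudo_count to the normalized probabilities (as the parameter description suggests) and renormalize, then sum p(x, y) log[p(x, y) / (p(x) p(y))]. The natural log is assumed, so values are in nats; the name is hypothetical.

```python
import numpy as np

def cross_sequence_mi_sketch(sequence1, sequence2, mask, n_categories, pseudo_count=1e-8):
    s1, s2 = np.asarray(sequence1), np.asarray(sequence2)
    mask = np.asarray(mask, dtype=bool)
    # Joint histogram of (sequence1[t], sequence2[t]) over valid timesteps.
    joint = np.zeros((n_categories, n_categories))
    np.add.at(joint, (s1[mask], s2[mask]), 1)
    # Normalize, add the pseudo count, and renormalize.
    p_xy = joint / joint.sum() + pseudo_count
    p_xy /= p_xy.sum()
    p_x = p_xy.sum(axis=1, keepdims=True)
    p_y = p_xy.sum(axis=0, keepdims=True)
    mi = np.sum(p_xy * np.log(p_xy / (p_x * p_y)))
    return mi
```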

state_moseq.util.lagged_mutual_information(sequences, mask, lags, pseudo_count=1e-08)

Compute mutual information at a range of lags for real data and equivalent Markov chains.

Parameters:
  • sequences (Int[Array, 'n_sequences n_timesteps']) – sequences from which to compute mutual information

  • mask (Bool[Array, 'n_sequences n_timesteps']) – mask indicating valid timesteps

  • lags (Int[Array, 'n_lags']) – array of temporal lags

  • pseudo_count (float) – pseudo count to use when computing mutual information

Returns:

  • real_mi – mutual information for each sequence at each lag

  • markov_mi – mutual information for equivalent Markov chains at each lag

  • shuff_mi – mutual information across randomly paired sequences at each lag
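A NumPy sketch of the real_mi and markov_mi outputs (shuff_mi is omitted). Two assumptions: every lag is at least 1, and the "equivalent Markov chain" is the first-order chain fitted to lag-1 transitions pooled over all sequences, whose lag-k joint distribution is pi_i (P^k)_ij. state_moseq may instead simulate such chains; all helper names are hypothetical.

```python
import numpy as np

def mi_from_joint(p_xy):
    """Mutual information (nats) of a normalized joint probability table."""
    p_x = p_xy.sum(axis=1, keepdims=True)
    p_y = p_xy.sum(axis=0, keepdims=True)
    nz = p_xy > 0
    return float(np.sum(p_xy[nz] * np.log(p_xy[nz] / (p_x @ p_y)[nz])))

def lagged_joint(seq, mask, lag, n_categories, pseudo_count=1e-8):
    """Joint distribution of (seq[t], seq[t + lag]) over valid pairs; assumes lag >= 1."""
    valid = mask[:-lag] & mask[lag:]
    joint = np.zeros((n_categories, n_categories))
    np.add.at(joint, (seq[:-lag][valid], seq[lag:][valid]), 1)
    joint = joint / joint.sum() + pseudo_count
    return joint / joint.sum()

def lagged_mi_sketch(sequences, mask, lags, pseudo_count=1e-8):
    sequences = np.asarray(sequences)
    mask = np.asarray(mask, dtype=bool)
    n = int(sequences.max()) + 1
    # real_mi[i, j]: MI between sequence i and itself shifted by lags[j].
    real_mi = np.array([[mi_from_joint(lagged_joint(s, m, lag, n, pseudo_count))
                         for lag in lags] for s, m in zip(sequences, mask)])
    # Assumed "equivalent" chain: pooled lag-1 transition statistics.
    joint1 = sum(lagged_joint(s, m, 1, n, pseudo_count) for s, m in zip(sequences, mask))
    joint1 = joint1 / joint1.sum()
    pi = joint1.sum(axis=1)   # marginal state distribution
    P = joint1 / pi[:, None]  # row-stochastic transition matrix
    # Lag-k joint of the chain started from pi: pi_i * (P^k)_ij.
    markov_mi = np.array([mi_from_joint(pi[:, None] * np.linalg.matrix_power(P, int(lag)))
                          for lag in lags])
    return real_mi, markov_mi
```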

state_moseq.util.save_hdf5(filepath, save_dict, datapath=None, overwrite_results=False)

Save a dict of pytrees to an hdf5 file. The leaves of the pytrees must be numpy arrays, scalars, or strings.

Parameters:
  • filepath (str) – Path of the hdf5 file to create.

  • save_dict (Dict[str, PyTree]) – Dictionary where the values are pytrees.

  • datapath (str | None) – Path within hdf5 file to save the data. If None, data are saved at the root.

  • overwrite_results (bool)

Return type:

None

state_moseq.util.load_hdf5(filepath, datapath=None)

Load a dict of pytrees from an hdf5 file.

Parameters:
  • filepath (str) – Path of the hdf5 file to load.

  • datapath (str | None) – Path within hdf5 file to load data from. If None, loads from the root.

Returns:

save_dict – Dictionary where the values are pytrees.

state_moseq.util.unbatch(data, keys, bounds)

Invert state_moseq.util.batch()

Parameters:
  • data (Array) – Stack of segmented time-series, shape (n_segs, seg_length, …).

  • keys (list[str] | Array) – Names of the time series that the segments came from (one per segment).

  • bounds (Int[Array, 'n_segs 2']) – Start and end indices for each segment.

Returns:

data_dict – Dictionary mapping names to reconstructed time series.

state_moseq.util.batch(data_dict, keys=None, seg_length=None, seg_overlap=30)

Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed-length segments. The data are padded so that the stacked array is not ragged: the padding repeats the last frame of each time series until the end of the segment.

Parameters:
  • data_dict (Dict[str, Array]) – Dictionary of time-series, each of shape (T, …).

  • keys (list[str] | None) – Optional list of keys to control order and inclusion of time-series.

  • seg_length (int | None) – Length of each segment. If None, defaults to the maximum sequence length.

  • seg_overlap (int) – Overlap between segments in frames.

Returns:

  • data – Stacked data array, shape (N, seg_length, …).

  • mask – Binary mask for valid data (1 = valid, 0 = padding), shape (N, seg_length).

  • metadata – Tuple (keys, bounds) identifying the source and position of each segment.
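A NumPy sketch of batch and its inverse, unbatch. It assumes segments start every seg_length - seg_overlap frames, that bounds are half-open [start, end) ranges into the original series, and that unbatch resolves overlapping frames by letting later segments overwrite earlier ones. When seg_length is None the sketch makes one padded row per series and ignores seg_overlap. state_moseq's exact segment boundaries may differ; the names are hypothetical.

```python
import numpy as np

def batch_sketch(data_dict, keys=None, seg_length=None, seg_overlap=30):
    keys = list(data_dict) if keys is None else list(keys)
    if seg_length is None:
        # No segmentation: one padded row per time series.
        seg_length = max(len(data_dict[k]) for k in keys)
        seg_overlap = 0
    stride = seg_length - seg_overlap  # assumed > 0
    segs, masks, seg_keys, bounds = [], [], [], []
    for key in keys:
        x = np.asarray(data_dict[key])
        T, start = len(x), 0
        while True:
            end = min(start + seg_length, T)
            seg = x[start:end]
            pad = seg_length - len(seg)
            # mode="edge" repeats the last frame along the time axis.
            widths = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
            segs.append(np.pad(seg, widths, mode="edge"))
            masks.append(np.arange(seg_length) < len(seg))
            seg_keys.append(key)
            bounds.append((start, end))
            if end >= T:
                break
            start += stride
    return np.stack(segs), np.stack(masks).astype(int), (seg_keys, np.array(bounds))

def unbatch_sketch(data, keys, bounds):
    data_dict = {}
    for key in dict.fromkeys(keys):  # unique keys, original order
        rows = [i for i, k in enumerate(keys) if k == key]
        T = max(bounds[i][1] for i in rows)
        out = np.empty((T,) + data.shape[2:], dtype=data.dtype)
        for i in rows:  # later segments overwrite overlapping frames
            start, end = bounds[i]
            out[start:end] = data[i, : end - start]
        data_dict[key] = out
    return data_dict
```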

state_moseq.util.get_durations(states_dict)

Get durations of high-level states.

Parameters:

states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of high-level state sequences.

Returns:

durations – Times between high-level state transitions (across all sequences).

Examples

>>> states_dict = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1]),
... }
>>> get_durations(states_dict)
array([2, 3, 1, 3, 1])
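A NumPy sketch of the run-length computation that reproduces the example above, assuming the run lengths of each sequence are concatenated in dictionary order. The name is hypothetical.

```python
import numpy as np

def get_durations_sketch(states_dict):
    durations = []
    for states in states_dict.values():
        states = np.asarray(states)
        # Indices where a new run begins, plus both ends of the sequence.
        change = np.flatnonzero(np.diff(states) != 0) + 1
        edges = np.concatenate([[0], change, [len(states)]])
        durations.append(np.diff(edges))
    return np.concatenate(durations)
```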

state_moseq.util.sample_instances(states_dict, num_instances)

Randomly sample instances of each state.

Parameters:
  • states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of state sequences.

  • num_instances (int) – Number of instances per state.

Returns:

sampled_instances – Dictionary mapping state index to instances.
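A NumPy sketch under two assumptions: an "instance" is one contiguous run of a state, recorded as (name, start, end) with end exclusive, and runs are sampled without replacement (all runs are kept when there are fewer than num_instances). The seed argument is hypothetical; the documented signature has none.

```python
import numpy as np

def sample_instances_sketch(states_dict, num_instances, seed=0):
    rng = np.random.default_rng(seed)
    runs = {}
    for name, states in states_dict.items():
        states = np.asarray(states)
        # Split each sequence into runs of a constant state.
        change = np.flatnonzero(np.diff(states) != 0) + 1
        starts = np.concatenate([[0], change])
        ends = np.concatenate([change, [len(states)]])
        for s, e in zip(starts, ends):
            runs.setdefault(int(states[s]), []).append((name, int(s), int(e)))
    sampled_instances = {}
    for state, instances in runs.items():
        k = min(num_instances, len(instances))
        idx = rng.choice(len(instances), size=k, replace=False)
        sampled_instances[state] = [instances[i] for i in idx]
    return sampled_instances
```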

state_moseq.util.get_frequencies(states_dict, num_states=None, runlength=False)

Get frequencies for a batch of high-level state sequences.

Parameters:
  • states_dict (Dict[str, Int[Array, 'n_timesteps']]) – Dictionary of high-level state sequences.

  • num_states (int | None) – Total number of states. If None, inferred from data.

  • runlength (bool) – If True, count only the first timepoint of each run of a state.

Returns:

frequencies – Frequency of each state across all state sequences.

Examples

>>> states_dict = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1]),
... }
>>> get_frequencies(states_dict, runlength=True)
array([0.2, 0.4, 0.2, 0.2])
>>> get_frequencies(states_dict, runlength=False)
array([0.3, 0.3, 0.3, 0.1])
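A NumPy sketch that reproduces both example outputs above: with runlength=True, only the first timepoint of each run is kept before counting. The name is hypothetical.

```python
import numpy as np

def get_frequencies_sketch(states_dict, num_states=None, runlength=False):
    seqs = [np.asarray(s) for s in states_dict.values()]
    if runlength:
        # Keep only the first timepoint of each run.
        seqs = [s[np.concatenate([[True], np.diff(s) != 0])] for s in seqs]
    all_states = np.concatenate(seqs)
    if num_states is None:
        num_states = int(all_states.max()) + 1
    counts = np.bincount(all_states, minlength=num_states)
    return counts / counts.sum()
```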

state_moseq.util.get_adjusted_rand(states_dict1, states_dict2, downsample=10)

Compute the adjusted Rand index between two sets of high-level state sequences.

Parameters:
  • states_dict1 (Dict[str, Int[Array, 'n_timesteps']]) – first dictionary of high-level state sequences

  • states_dict2 (Dict[str, Int[Array, 'n_timesteps']]) – second dictionary of high-level state sequences

  • downsample (int) – downsampling factor used to reduce the length of the sequences

Returns:

adjusted_rand_index – adjusted Rand index between the two sets of state sequences
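A NumPy sketch of the standard adjusted Rand index computed from the contingency table, assuming sequences are matched by key, each is subsampled by keeping every downsample-th timepoint, and the results are concatenated. sklearn.metrics.adjusted_rand_score computes the same quantity for two label arrays; the names here are hypothetical.

```python
import numpy as np

def pairs(x):
    """Number of unordered pairs, n choose 2, elementwise."""
    return x * (x - 1) / 2

def adjusted_rand_sketch(states_dict1, states_dict2, downsample=10):
    keys = [k for k in states_dict1 if k in states_dict2]
    # Concatenate shared sequences, keeping every downsample-th timepoint.
    a = np.concatenate([np.asarray(states_dict1[k])[::downsample] for k in keys])
    b = np.concatenate([np.asarray(states_dict2[k])[::downsample] for k in keys])
    # Contingency table between the two labelings.
    _, a_idx = np.unique(a, return_inverse=True)
    _, b_idx = np.unique(b, return_inverse=True)
    table = np.zeros((a_idx.max() + 1, b_idx.max() + 1))
    np.add.at(table, (a_idx, b_idx), 1)
    sum_ij = pairs(table).sum()
    sum_a = pairs(table.sum(axis=1)).sum()
    sum_b = pairs(table.sum(axis=0)).sum()
    expected = sum_a * sum_b / pairs(len(a))
    max_index = 0.5 * (sum_a + sum_b)
    return (sum_ij - expected) / (max_index - expected)
```

The index is invariant to relabeling, so two segmentations that differ only by a permutation of state labels score 1.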