import jax.numpy as jnp
import jax.random as jr
import jax
import optax
import tqdm
import h5py
import numpy as np
from functools import partial
from tensorflow_probability.substrates.jax import distributions as tfd
from jaxtyping import Array, Float, Int, PyTree, Bool
from typing import Tuple, Union, Callable, Dict, Optional, List
from scipy.optimize import linear_sum_assignment
from dynamax.utils.optimize import run_gradient_descent
from jax.scipy.linalg import cho_factor, cho_solve
from sklearn.metrics import adjusted_rand_score
na = jnp.newaxis
[docs]
@jax.vmap
def logits_to_probs(
    logits: Float[Array, "n_categories-1"]
) -> Float[Array, "n_categories"]:
    """Map unconstrained logits to a probability vector.

    The function is wrapped in ``jax.vmap``, so it is called on a batch of
    logit vectors (leading batch axis) and each row is converted on its own.
    A zero logit is appended for the last category before the softmax, so
    ``n_categories - 1`` logits yield ``n_categories`` probabilities.

    Args:
        logits: logits of all categories except the last one.
    Returns:
        probs: probabilities over all categories.
    """
    reference_logit = jnp.zeros(1)
    full_logits = jnp.concatenate([logits, reference_logit])
    return jax.nn.softmax(full_logits)
[docs]
@jax.vmap
def probs_to_logits(
    probs: Float[Array, "n_categories"],
    pseudo_count: Float = 1e-8,
) -> Float[Array, "n_categories-1"]:
    """Map a probability vector to logits relative to the last category.

    The function is wrapped in ``jax.vmap`` and therefore operates on a batch
    of probability vectors (leading batch axis).

    NOTE(review): with the default ``in_axes`` of ``jax.vmap``, an explicitly
    passed ``pseudo_count`` would be mapped over its leading axis as well;
    callers presumably rely on the default value -- confirm.

    Args:
        probs: probabilities over all categories.
        pseudo_count: small constant added to ``probs`` before taking the log.
    Returns:
        logits: log-probabilities of all but the last category, minus the
            log-probability of the last category.
    """
    smoothed = probs + pseudo_count
    log_p = jnp.log(smoothed)
    reference = log_p[-1]
    return log_p[:-1] - reference
[docs]
def normal_inverse_gamma_posterior(
    seed: Float[Array, "2"],
    mean: Float,
    sigmasq: Float,
    n: Int,
    lambda_: Float,
    alpha: Float,
    beta: Float,
) -> Tuple[Float, Float]:
    """
    Sample a mean and a variance given a normal-inverse gamma prior.

    Non-finite ``mean`` / ``sigmasq`` inputs are replaced by finite values via
    ``jnp.nan_to_num`` (NaN becomes 0) before they are used.

    NOTE(review): ``mu`` is drawn around the sample mean with scale
    ``sqrt(sigmasq / (lambda_ + n))``, i.e. using the *sample* variance rather
    than the sampled ``sigma``; a textbook normal-inverse-gamma draw would use
    the sampled variance and the shrunk mean ``n * mean / (lambda_ + n)`` --
    confirm this is intended.

    Args:
        seed: random seed
        mean: sample mean
        sigmasq: sample variance
        n: number of data points
        lambda_: strength of prior
        alpha: inverse gamma shape parameter
        beta: inverse gamma rate parameter
    Returns:
        mu: sampled mean
        sigma: sampled variance
    """
    subkeys = jr.split(seed, 2)
    sample_mean = jnp.nan_to_num(mean)
    sample_var = jnp.nan_to_num(sigmasq)

    # Updated hyper-parameters (the squared-mean term corresponds to a prior
    # mean of zero).
    post_lambda = lambda_ + n
    post_alpha = alpha + n / 2
    scatter_term = 0.5 * n * sample_var
    shrinkage_term = 0.5 * n * lambda_ * (sample_mean**2) / post_lambda
    post_beta = beta + scatter_term + shrinkage_term

    sigma = sample_inv_gamma(subkeys[0], post_alpha, post_beta)
    mu_scale = jnp.sqrt(sample_var / post_lambda)
    mu = jr.normal(subkeys[1]) * mu_scale + sample_mean
    return mu, sigma
[docs]
def center_embedding(n: int) -> Float[Array, "n n-1"]:
    """Generate an orthonormal matrix that embeds R^(n-1) into the space of 0-sum vectors in R^n.

    Row ``i`` of the intermediate basis is ``e_i`` minus the average of
    ``e_0 .. e_{i-1}`` (so it sums to zero); rows are then scaled to unit norm
    and the transpose is returned.

    Args:
        n: dimension of the ambient space.
    Returns:
        Matrix of shape ``(n, n - 1)`` whose columns are the basis vectors.
    """
    # Rows 1..n-1 of the strictly lower-triangular matrix of ones.
    lower_ones = jnp.tril(jnp.ones((n, n)), k=-1)[1:]
    row_counts = lower_ones.sum(1)[:, None]
    basis = jnp.eye(n)[1:] - lower_ones / row_counts
    row_norms = jnp.sqrt((basis**2).sum(1))[:, None]
    return (basis / row_norms).T
[docs]
def lower_dim(arr, axis=0):
    """Lower dimension in specified axis by projecting onto the space of 0-sum vectors.

    Args:
        arr: input array.
        axis: axis (of size ``k``) to reduce to size ``k - 1``.
    Returns:
        Array with the same shape as ``arr`` except that ``axis`` has size ``k - 1``.
    """
    moved = jnp.moveaxis(arr, axis, 0)
    k = moved.shape[0]
    trailing_shape = moved.shape[1:]
    # Flatten everything but the leading axis so one matmul handles all columns.
    flat = moved.reshape(k, -1)
    projected = center_embedding(k).T @ flat
    lowered = projected.reshape(k - 1, *trailing_shape)
    return jnp.moveaxis(lowered, 0, axis)
[docs]
def raise_dim(arr, axis=0):
    """Raise dimension in specified axis by embedding into the space of 0-sum vectors.

    Args:
        arr: input array.
        axis: axis (of size ``k``) to expand to size ``k + 1``.
    Returns:
        Array with the same shape as ``arr`` except that ``axis`` has size ``k + 1``.
    """
    moved = jnp.moveaxis(arr, axis, 0)
    k = moved.shape[0]
    trailing_shape = moved.shape[1:]
    # Flatten everything but the leading axis so one matmul handles all columns.
    flat = moved.reshape(k, -1)
    embedded = center_embedding(k + 1) @ flat
    raised = embedded.reshape(k + 1, *trailing_shape)
    return jnp.moveaxis(raised, 0, axis)
def sample_multinomial(
    seed: Float[Array, "2"],
    n: Int,
    p: Float[jnp.ndarray, "n_categories"],
) -> Int[Array, "n_categories"]:
    """Draw one sample of category counts from ``tfd.Multinomial(n, probs=p)``.

    Args:
        seed: random seed
        n: total count passed to the multinomial distribution
        p: category probabilities
    Returns:
        counts: sampled counts per category
    """
    distribution = tfd.Multinomial(n, probs=p)
    return distribution.sample(seed=seed)
def sample_gamma(
    seed: Float[Array, "2"],
    a: Float,
    b: Float,
) -> Float:
    """Sample from a gamma distribution with shape ``a`` and rate ``b``.

    A standard gamma draw with shape ``a`` is divided by ``b``.

    Args:
        seed: random seed
        a: shape parameter
        b: rate parameter (divides the standard gamma draw)
    Returns:
        sample: gamma-distributed value
    """
    unit_rate_draw = jr.gamma(seed, a)
    return unit_rate_draw / b
def sample_inv_gamma(
    seed: Float[Array, "2"],
    a: Float,
    b: Float,
) -> Float:
    """Sample from an inverse gamma distribution.

    The sample is the reciprocal of ``sample_gamma(seed, a, b)``.

    Args:
        seed: random seed
        a: shape parameter
        b: rate parameter of the underlying gamma draw
    Returns:
        sample: inverse-gamma-distributed value
    """
    gamma_draw = sample_gamma(seed, a, b)
    return 1.0 / gamma_draw
[docs]
@partial(jax.jit, static_argnames=["n_timesteps"])
def simulate_markov_chain(
    seed: Float[Array, "2"],
    trans_probs: Union[
        Float[Array, "n_states n_states"], Float[Array, "n_timesteps n_states n_states"]
    ],
    n_timesteps: Int,
    init_probs: Optional[Float[Array, "n_states"]] = None,
) -> Int[Array, "n_timesteps"]:
    """Simulate a state sequence from a Markov chain.

    An initial state is drawn from ``init_probs`` and then ``n_timesteps``
    transitions are simulated. The returned sequence contains the state after
    each transition; the initial state itself is not part of the output.

    Args:
        seed: random seed
        trans_probs: transition probabilities between states. Either a single
            ``(n_states, n_states)`` matrix used at every step, or a stack of
            ``n_timesteps`` matrices (one per step).
        n_timesteps: number of timesteps to simulate (static under ``jit``)
        init_probs: initial state probabilities. If None, uniform distribution is used.
    Returns:
        states: simulated state sequence
    """
    seeds = jr.split(seed, n_timesteps + 1)
    # The state dimension is always the LAST axis. For time-varying transitions
    # the leading axis is time, so `shape[0]` would size the default initial
    # distribution by the number of timesteps instead of the number of states.
    n_states = trans_probs.shape[-1]
    log_trans_probs = jnp.log(trans_probs)

    if init_probs is None:
        # Equal logits == uniform distribution over states.
        log_init_probs = jnp.zeros(n_states)
    else:
        log_init_probs = jnp.log(init_probs)
    init_state = jr.categorical(seeds[0], log_init_probs)

    if trans_probs.ndim == 2:
        # Stationary chain: the same transition matrix at every step.
        def step(state, seed):
            next_state = jr.categorical(seed, log_trans_probs[state])
            return next_state, next_state

        _, states = jax.lax.scan(step, init_state, seeds[1:])
    else:
        # Time-varying chain: scan over (seed, transition matrix) pairs.
        def step(state, args):
            seed, logT = args
            next_state = jr.categorical(seed, logT[state])
            return next_state, next_state

        _, states = jax.lax.scan(step, init_state, (seeds[1:], log_trans_probs))
    return states
[docs]
def count_transitions(
    states: Int[Array, "n_timesteps"],
    mask: Int[Array, "n_timesteps"],
    n_states: int,
) -> Float[Array, "n_states n_states"]:
    """Count transitions between states.

    Each transition from timestep ``t`` to ``t + 1`` contributes ``mask[t]``
    (the mask value at the source timestep) to entry
    ``[states[t], states[t + 1]]``.

    Args:
        states: discrete state sequence
        mask: mask of valid observations
        n_states: number of discrete states
    Returns:
        trans_counts: transition counts
    """
    sources = states[:-1]
    targets = states[1:]
    weights = mask[:-1]
    empty_counts = jnp.zeros((n_states, n_states))
    return empty_counts.at[sources, targets].add(weights)
[docs]
def compare_states(
    states1: Union[Int[Array, "n_timesteps"], Dict[str, Int[Array, "n_timesteps"]]],
    states2: Union[Int[Array, "n_timesteps"], Dict[str, Int[Array, "n_timesteps"]]],
    n_states: Int = None,
) -> Tuple[Int[Array, "n_states n_states"], Int[Array, "n_states"], Float]:
    """Compare high-level state sequences.

    Args:
        states1: first set of state sequences (can be an array or dictionary of sequences)
        states2: second set of state sequences (can be an array or dictionary of sequences)
        n_states: number of discrete states. If None, inferred from data.
    Returns:
        confusion_matrix: confusion matrix, normalized so that each row
            (label in ``states1``) sums to one. Rows without any counts
            divide zero by zero.
        optimal_permutation: optimal permutation of first set of states to match second set
        accuracy: proportion of timepoints with matching labels (after optimal permutation)
    """
    # Dictionaries of sequences are flattened into a single label array.
    states1, states2 = (
        _concatenate_stateseqs(seqs) if isinstance(seqs, dict) else seqs
        for seqs in (states1, states2)
    )
    if n_states is None:
        n_states = max(states1.max(), states2.max()) + 1

    # counts[i, j] = number of timepoints labelled i in states1 and j in states2.
    counts = jnp.zeros((n_states, n_states)).at[states1, states2].add(1)

    # Maximizing matched counts == minimizing the negated count matrix.
    _, optimal_perm = linear_sum_assignment(-counts.T)
    matched = counts[optimal_perm, jnp.arange(n_states)].sum()
    accuracy = matched / states2.size

    row_totals = counts.sum(axis=1, keepdims=True)
    confusion = counts / row_totals
    return confusion, optimal_perm, accuracy
[docs]
def sample_hmc(
    seed: Float[Array, "2"],
    log_prob_fn: Callable,
    init_params: PyTree,
    num_leapfrog_steps: Int = 3,
    step_size: Float = 0.001,
    num_results: Int = 1,
    num_burnin_steps: Int = 100,
) -> Tuple[PyTree, PyTree]:
    """Sample using Hamiltonian Monte Carlo.

    Args:
        seed: random seed
        log_prob_fn: target log probability function
        init_params: initial state of the chain
        num_leapfrog_steps: number of leapfrog steps per HMC proposal
        step_size: leapfrog step size
        num_results: number of samples to return
        num_burnin_steps: number of burn-in steps before samples are collected
    Returns:
        params: sampled states returned by ``tfp.mcmc.sample_chain``
        kernel_state: final kernel results returned by ``tfp.mcmc.sample_chain``
    """
    # The module only imports `distributions as tfd`; `tfp` itself was never
    # bound, so referencing `tfp.mcmc` raised NameError. Import the JAX
    # substrate locally to keep this block self-contained.
    from tensorflow_probability.substrates import jax as tfp

    hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
    )
    # With `return_final_kernel_results=True` the result unpacks into
    # (states, trace, final kernel results); the trace is unused here.
    params, _, kernel_state = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=init_params,
        kernel=hmc_kernel,
        seed=seed,
        trace_fn=None,
        return_final_kernel_results=True,
    )
    return params, kernel_state
[docs]
def sample_laplace(
    seed: Float[Array, "2"],
    log_prob_fn: Callable,
    init_params: PyTree,
    gradient_descent_iters: Int = 200,
    gradient_descent_lr: Float = 0.01,
) -> Tuple[PyTree, Float[Array, "gradient_descent_iters"]]:
    """Sample using Laplace approximation. Uses gradient descent to find mode of posterior.

    Args:
        seed: random seed
        log_prob_fn: log probability function
        init_params: initial parameters
        gradient_descent_iters: number of gradient descent iterations
        gradient_descent_lr: gradient descent learning rate
    Returns:
        params: sampled parameters (same pytree structure as the mode found
            by gradient descent)
        losses: loss history
    """
    # `jax.flatten_util` is a submodule; attribute access through `jax` only
    # works if something else has already imported it. Import explicitly.
    from jax.flatten_util import ravel_pytree

    def neg_log_prob(params):
        return -log_prob_fn(params)

    # find the mode of the posterior
    mode, _, losses = run_gradient_descent(
        neg_log_prob,
        init_params,
        num_mstep_iters=gradient_descent_iters,
        optimizer=optax.adam(gradient_descent_lr),
    )

    # calculate covariance matrix from hessian at mode
    flat_mode, unravel_fn = ravel_pytree(mode)

    def flat_log_prob(flat_params):
        return log_prob_fn(unravel_fn(flat_params))

    hessian_at_mode = jax.hessian(flat_log_prob)(flat_mode)
    # Covariance of the Gaussian approximation = inverse of the negative
    # Hessian; the diagonal boost is applied inside `psd_inv`.
    covariance_matrix = psd_inv(-hessian_at_mode, diagonal_boost=1e-2)

    # sample from laplace approximation
    x = jr.multivariate_normal(seed, mean=flat_mode, cov=covariance_matrix)
    return unravel_fn(x), losses
[docs]
def symmetrize(A: Float[Array, "n n"]) -> Float[Array, "n n"]:
    """Symmetrize a matrix by averaging it with its transpose.

    The transpose swaps the last two axes, so leading batch axes are left
    untouched.
    """
    transposed = A.swapaxes(-2, -1)
    return (A + transposed) / 2
[docs]
def psd_solve(
    A: Float[Array, "n n"],
    B: Float[Array, "n m"],
    diagonal_boost: float = 1e-6
) -> Float[Array, "n m"]:
    """Solve the linear system Ax = B, where A is a positive semi-definite matrix.

    ``A`` is symmetrized and ``diagonal_boost`` times the identity is added
    before a Cholesky factorization is used to solve the system.

    Args:
        A: positive semi-definite matrix
        B: right-hand side matrix
        diagonal_boost: boost to diagonal to ensure positive definiteness
    Returns:
        x: solution to the linear system
    """
    n = A.shape[-1]
    regularized = symmetrize(A) + diagonal_boost * jnp.eye(n)
    # `cho_factor` returns the (factor, lower) pair that `cho_solve` expects.
    factorization = cho_factor(regularized, lower=True)
    return cho_solve(factorization, B)
[docs]
def psd_inv(A : Float[Array, "n n"], diagonal_boost: float = 1e-6) -> Float[Array, "n n"]:
    """Compute the inverse of a positive semi-definite matrix using Cholesky decomposition.

    The inverse is obtained by solving ``A X = I`` with ``psd_solve`` and the
    result is symmetrized before it is returned.

    Args:
        A: positive semi-definite matrix
        diagonal_boost: boost to diagonal to ensure positive definiteness
    Returns:
        Ainv: inverse of the matrix
    """
    identity = jnp.eye(A.shape[-1])
    inverse = psd_solve(A, identity, diagonal_boost=diagonal_boost)
    return symmetrize(inverse)
[docs]
def save_hdf5(
    filepath: str,
    save_dict: Dict[str, PyTree],
    datapath: Optional[str] = None,
    overwrite_results: bool = False,
) -> None:
    """Save a dict of pytrees to an hdf5 file. The leaves of the pytrees must
    be numpy arrays, scalars, or strings.

    The file is opened in append mode, so other entries already present in the
    file are kept. Arrays are fetched from device with ``jax.device_get``
    before they are written.

    Args:
        filepath: Path of the hdf5 file to create.
        save_dict: Dictionary where the values are pytrees.
        datapath: Path within hdf5 file to save the data. If None, data are saved at the root.
        overwrite_results: NOTE(review): this flag is not read anywhere in the
            function body -- confirm whether it should guard overwriting.
    """
    with h5py.File(filepath, "a") as h5file:
        if datapath is None:
            # One top-level entry per dictionary key.
            for key, subtree in save_dict.items():
                _savetree_hdf5(jax.device_get(subtree), h5file, key)
        else:
            # The whole dictionary is stored as a single tree under `datapath`.
            _savetree_hdf5(jax.device_get(save_dict), h5file, datapath)
[docs]
def load_hdf5(
    filepath: str,
    datapath: Optional[str] = None,
) -> Dict[str, PyTree]:
    """Load a dict of pytrees from an hdf5 file.

    Args:
        filepath: Path of the hdf5 file to load.
        datapath: Path within hdf5 file to load data from. If None, loads from the root.
    Returns:
        save_dict: Dictionary where the values are pytrees. When ``datapath``
            is given, the tree stored at that path is returned instead.
    """
    with h5py.File(filepath, "r") as h5file:
        if datapath is not None:
            return _loadtree_hdf5(h5file[datapath])
        # No path given: load every top-level entry of the file.
        return {name: _loadtree_hdf5(h5file[name]) for name in h5file}
def _savetree_hdf5(tree: PyTree, group: h5py.Group, name: str) -> None:
    """Recursively save a pytree to an h5 file group.

    Leaves (numpy arrays, Python/numpy scalars, strings) become datasets.
    Tuples, lists and dicts become groups carrying a ``"type"`` attribute
    (``"tuple"``, ``"list"`` or ``"dict"``); sequence items are stored under
    the names ``arr0``, ``arr1``, ...

    Args:
        tree: pytree to save.
        group: h5py group (or file) to write into.
        name: name of the dataset/group to create; an existing entry with
            this name is replaced.
    Raises:
        ValueError: if ``tree`` is not of a supported type. The check happens
            before the existing entry ``name`` is touched, so an unsupported
            top-level value does not delete previously stored data.
    """
    is_array = isinstance(tree, np.ndarray)
    # `np.generic` covers numpy scalars such as np.float32 / np.int64, which
    # are not instances of the builtin float / int.
    is_scalar = isinstance(tree, (float, int, str, np.generic))
    if isinstance(tree, dict):
        container_type = "dict"
    elif isinstance(tree, list):
        container_type = "list"
    elif isinstance(tree, tuple):
        container_type = "tuple"
    else:
        container_type = None

    # Validate BEFORE mutating the file.
    if not (is_array or is_scalar or container_type is not None):
        raise ValueError(f"Unrecognized type {type(tree)}")

    if name in group:
        del group[name]

    if is_array:
        if tree.dtype.kind == "U":
            # Unicode arrays are stored as variable-length strings.
            dt = h5py.special_dtype(vlen=str)
            group.create_dataset(name, data=tree.astype(object), dtype=dt)
        else:
            group.create_dataset(name, data=tree)
    elif is_scalar:
        group.create_dataset(name, data=tree)
    else:
        subgroup = group.create_group(name)
        # Store the canonical container name so that subclasses (e.g. named
        # tuples, OrderedDict) can be read back by the loader.
        subgroup.attrs["type"] = container_type
        if container_type == "dict":
            for k, subtree in tree.items():
                _savetree_hdf5(subtree, subgroup, k)
        else:
            for k, subtree in enumerate(tree):
                _savetree_hdf5(subtree, subgroup, f"arr{k}")
def _loadtree_hdf5(leaf: Union[h5py.Dataset, h5py.Group]) -> PyTree:
    """Recursively load a pytree from an h5 file group.

    Datasets are returned as numpy arrays, except that string data is decoded
    to ``str`` and zero-dimensional data is unwrapped to a Python scalar.
    Groups are rebuilt as ``dict``, ``list`` or ``tuple`` according to their
    ``"type"`` attribute.

    Args:
        leaf: dataset or group to load.
    Returns:
        The reconstructed pytree.
    Raises:
        ValueError: if a group's ``"type"`` attribute is not one of
            ``"dict"``, ``"list"`` or ``"tuple"``.
    """
    if isinstance(leaf, h5py.Dataset):
        data = np.array(leaf[()])
        if h5py.check_dtype(vlen=data.dtype) == str:
            data = np.array([item.decode("utf-8") for item in data])
        elif data.dtype.kind == "S":
            data = data.item().decode("utf-8")
        elif data.shape == ():
            data = data.item()
        return data

    leaf_type = leaf.attrs["type"]
    if leaf_type == "dict":
        return {name: _loadtree_hdf5(child) for name, child in leaf.items()}

    if leaf_type in ("list", "tuple"):

        def sequence_index(name):
            # Sequence items are written as "arr<k>". Sort those by the integer
            # k; anything else keeps name order after the indexed items.
            suffix = name[len("arr"):]
            if name.startswith("arr") and suffix.isdigit():
                return (0, int(suffix), name)
            return (1, 0, name)

        # By default h5py iterates group members in name order, which puts
        # "arr10" before "arr2" and would scramble sequences with more than
        # ten items. Restore the original numeric order explicitly.
        ordered_names = sorted(leaf.keys(), key=sequence_index)
        items = [_loadtree_hdf5(leaf[name]) for name in ordered_names]
        return items if leaf_type == "list" else tuple(items)

    raise ValueError(f"Unrecognized type {leaf_type}")
[docs]
def unbatch(
    data: Array,
    keys: Union[list[str], Array],
    bounds: Int[Array, "n_segs 2"]
) -> Dict[str, Array]:
    """Invert :py:func:`state_moseq.util.batch`

    Segments belonging to the same key are written back into one array at the
    positions given by ``bounds``; where segments overlap, later segments
    overwrite earlier ones.

    Args:
        data: Stack of segmented time-series, shape (n_segs, seg_length, ...).
        keys: Name of the time-series that each segment came from
            (a list or an array with one entry per segment).
        bounds: Start and end indices for each segment.
    Returns:
        data_dict: Dictionary mapping names to reconstructed time-series.
    """
    # Coerce to numpy so that `keys == key` is an elementwise comparison even
    # when `keys` is a plain Python list (where `==` would yield a single bool).
    keys = np.asarray(keys)
    bounds = np.asarray(bounds)

    data_dict = {}
    # np.unique gives a deterministic (sorted) iteration order.
    for key in np.unique(keys):
        selected = keys == key
        key_bounds = bounds[selected]
        # The reconstructed length is the largest segment end for this key.
        length = int(key_bounds[:, 1].max())
        seq = np.zeros((length, *data.shape[2:]), dtype=data.dtype)
        for (s, e), d in zip(key_bounds, data[selected]):
            # Drop the padding: only the first e - s frames are real data.
            seq[s:e] = d[: e - s]
        data_dict[key] = seq
    return data_dict
[docs]
def batch(
    data_dict: Dict[str, Array],
    keys: Optional[list[str]] = None,
    seg_length: Optional[int] = None,
    seg_overlap: int = 30,
) -> Tuple[Array, Int[Array, "N seg_length"], Tuple[list[str], Int[Array, "N 2"]]]:
    """Stack time-series data of different lengths into a single array for batch
    processing, optionally breaking up the data into fixed length segments. The
    data is padded so that the stacked array isn't ragged. The padding
    repeats the last frame of each time-series until the end of the segment.

    Segments start every ``seg_length`` frames and each stacked row holds
    ``seg_length + seg_overlap`` frames.

    Args:
        data_dict: Dictionary of time-series, each of shape (T, ...).
        keys: Optional list of keys to control order and inclusion of time-series.
            Defaults to the sorted keys of ``data_dict``.
        seg_length: Stride between segment starts. Defaults to max sequence length.
        seg_overlap: Overlap between segments in frames.
    Returns:
        data: Stacked data array, shape (N, seg_length + seg_overlap, ...).
        mask: Binary mask for valid data (1 = valid, 0 = padding),
            shape (N, seg_length + seg_overlap).
        metadata: Tuple (keys, bounds), identifying sources and segment positions.
    """
    if keys is None:
        keys = sorted(data_dict.keys())
    lengths = [len(data_dict[key]) for key in keys]
    if seg_length is None:
        seg_length = np.max(lengths)
    # Number of frames held by every stacked row.
    window = seg_length + seg_overlap

    segments, masks, seg_keys, seg_bounds = [], [], [], []
    for key, total in zip(keys, lengths):
        series = data_dict[key]
        for start in range(0, total, seg_length):
            end = min(start + window, total)
            n_valid = end - start
            n_pad = window - n_valid
            # Repeat the last valid frame to fill the rest of the window.
            tail = np.repeat(series[end - 1 : end], n_pad, axis=0)
            segments.append(np.concatenate([series[start:end], tail], axis=0))
            masks.append(np.hstack([np.ones(n_valid), np.zeros(n_pad)]))
            seg_keys.append(key)
            seg_bounds.append((start, end))

    metadata = (np.array(seg_keys), np.array(seg_bounds))
    return np.stack(segments), np.stack(masks), metadata
[docs]
def get_durations(
    states_dict: Dict[str, Int[Array, "n_timesteps"]]
) -> Int[Array, "n_durations"]:
    """Get durations of high-level states.

    All sequences are concatenated (in dictionary order) before run lengths
    are measured, so a run that ends one sequence and a run of the same state
    that starts the next one are counted as a single duration. The value -1
    is used as a sentinel at both ends of the concatenated sequence.

    Args:
        states_dict: Dictionary of high-level state sequences.
    Returns:
        durations: Times between high-level state transitions (across all sequences).
    Examples:
        >>> states_dict = {
        ...     'name1': np.array([1, 1, 2, 2, 2, 3]),
        ...     'name2': np.array([0, 0, 0, 1]),
        ... }
        >>> get_durations(states_dict)
        array([2, 3, 1, 3, 1])
    """
    flat = np.hstack(list(states_dict.values()))
    # Sentinels guarantee a changepoint at the very start and very end.
    padded = np.hstack([[-1], flat, [-1]])
    boundaries = np.flatnonzero(np.diff(padded))
    return np.diff(boundaries)
[docs]
def sample_instances(
    states_dict: Dict[str, Int[Array, "n_timesteps"]],
    num_instances: int,
) -> Dict[int, List[Tuple[str, int, int]]]:
    """Randomly sample instances of each state.

    An instance is a maximal run of one state within a single sequence,
    described as ``(key, start, end)`` with ``end`` exclusive. Sampling uses
    NumPy's global random state (``np.random.permutation``).

    Args:
        states_dict: Dictionary of state sequences.
        num_instances: Number of instances per state. States with fewer
            instances return all of them.
    Returns:
        sampled_instances: Dictionary mapping state index to instances.
    """
    observed_states = np.unique(np.hstack(list(states_dict.values())))
    runs_by_state = {state_ix: [] for state_ix in observed_states}

    for key, stateseq in states_dict.items():
        # Indices at which a new run begins (excluding index 0).
        change_idx = np.flatnonzero(stateseq[1:] != stateseq[:-1]) + 1
        starts = np.concatenate([[0], change_idx])
        ends = np.concatenate([change_idx, [len(stateseq)]])
        for start, end, state in zip(starts, ends, stateseq[starts]):
            runs_by_state[state].append((key, start, end))

    sampled_instances = {}
    for state_ix, runs in runs_by_state.items():
        order = np.random.permutation(len(runs))
        sampled_instances[state_ix] = [runs[i] for i in order[:num_instances]]
    return sampled_instances
def _concatenate_stateseqs(states_dict):
"""Concatenate high-level state sequences from a dictionary into a single array."""
return np.hstack([states_dict[key] for key in sorted(states_dict.keys())]).astype(int)
[docs]
def get_frequencies(
    states_dict: Dict[str, Int[Array, "n_timesteps"]],
    num_states: Optional[int] = None,
    runlength: bool = False,
) -> Float[Array, "n_states"]:
    """Get frequencies for a batch of high-level state sequences.

    Args:
        states_dict: Dictionary of high-level state sequences.
        num_states: Total number of states. If None, inferred from data.
        runlength: If True, count only the first timepoint of each run of a state.
    Returns:
        frequencies: Frequency of each state across all state sequences
    Examples:
        >>> states_dict = {
        ...     'name1': np.array([1, 1, 2, 2, 2, 3]),
        ...     'name2': np.array([0, 0, 0, 1])}
        >>> get_frequencies(states_dict, runlength=True)
        array([0.2, 0.4, 0.2, 0.2])
        >>> get_frequencies(states_dict, runlength=False)
        array([0.3, 0.3, 0.3, 0.1])
    """
    labels = _concatenate_stateseqs(states_dict)
    if num_states is None:
        num_states = np.max(labels) + 1
    if runlength:
        # Keep one label per run: index 0 plus every index where the label changes.
        change_idx = np.diff(labels).nonzero()[0] + 1
        onsets = np.concatenate([[0], change_idx])
        labels = labels[onsets]
    counts = np.bincount(labels, minlength=num_states)
    return counts / counts.sum()
[docs]
def get_adjusted_rand(
    states_dict1: Dict[str, Int[Array, "n_timesteps"]],
    states_dict2: Dict[str, Int[Array, "n_timesteps"]],
    downsample: int = 10,
) -> float:
    """Compute the adjusted Rand index between two sets of high-level state sequences.

    Both dictionaries are concatenated into one label array each and every
    ``downsample``-th label is kept before the score is computed.

    Args:
        states_dict1: First dictionary of high-level state sequences.
        states_dict2: Second dictionary of high-level state sequences.
        downsample: Downsampling factor to reduce the length of the sequences.
    Returns:
        adjusted_rand_index: Adjusted Rand index between the two sets of state sequences.
    """
    labels1, labels2 = (
        _concatenate_stateseqs(seqs)[::downsample]
        for seqs in (states_dict1, states_dict2)
    )
    return adjusted_rand_score(labels1, labels2)