# Source code for ethograph.labels.ml

"""Dense (array-based) label operations for ML pipelines.

This module provides tools for converting between the interval-based label
format (used by the GUI and TSV storage) and dense integer arrays (used by
ML models).  It also contains post-processing operations commonly applied
to model predictions before evaluation or storage.

Typical ML workflow
-------------------
1. **Load labels from TSV** → ``pd.DataFrame`` with ``onset_s``, ``offset_s``,
   ``labels``, ``individual`` (plus ``n_samples`` per-trial metadata).
2. **Convert to dense** → ``intervals_to_dense(df, sample_rate, individuals, n_samples)``
   gives an ``(n_samples, n_individuals)`` int8 array ready for training.
3. **Run model** → get a dense prediction array of shape ``(T,)`` or ``(T, n_classes)``.
4. **Post-process** → ``purge_small_blocks`` → ``stitch_gaps`` → ``fix_endings``.
5. **Convert back** → ``dense_to_intervals(pred, individuals, sample_rate=sr)``
   gives an intervals DataFrame for storage or evaluation.

The ``n_samples`` value stored in the TSV file (per-trial metadata) tells you
exactly how long the dense array should be — you only need to additionally
know the ``sample_rate`` to drive the conversion.
"""

from __future__ import annotations

import numpy as np
import pandas as pd

from ethograph.labels.intervals import _rows_to_df, states_only


# ── Primitives ───────────────────────────────────────────────────────────

def find_blocks(mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Locate runs of consecutive True values in a boolean array.

    Parameters
    ----------
    mask : np.ndarray
        Boolean array.

    Returns
    -------
    starts : np.ndarray
        Index where each True run begins.
    ends : np.ndarray
        Index where each True run ends (**inclusive**).

    Examples
    --------
    >>> import numpy as np
    >>> mask = np.array([False, True, True, False, True])
    >>> starts, ends = find_blocks(mask)
    >>> starts
    array([1, 4])
    >>> ends
    array([2, 4])
    """
    # Pad with zeros so runs touching either edge still produce an edge pair.
    edges = np.diff(np.concatenate(([0], mask.astype(int), [0])))
    run_starts = np.flatnonzero(edges == 1)
    # A -1 edge sits one past the last True sample; shift back for inclusive ends.
    run_ends = np.flatnonzero(edges == -1) - 1
    return run_starts, run_ends
def _get_segments(col, bg_class=0): """Find contiguous labeled segments in a 1-D array. Example: ``[0,1,1,1,0,2,2]`` → ``[(1,1,4), (2,5,7)]`` Each tuple is ``(label, start_index, end_index_exclusive)``. """ padded = np.concatenate([[-1], col, [-1]]) change_indices = np.nonzero(padded[:-1] != padded[1:])[0] segments = [] for i in range(len(change_indices) - 1): start = change_indices[i] end = change_indices[i + 1] label = int(col[start]) if label != bg_class: segments.append((label, start, end)) return segments def _get_labels_start_end_times(col, time_coord, individual, bg_class=0): """Convert segments to time intervals (inclusive end).""" segments = _get_segments(col, bg_class) return [{ "onset_s": float(time_coord[start]), "offset_s": float(time_coord[end - 1]), "labels": label, "individual": individual, } for label, start, end in segments] # ── Interval ↔ Dense conversion ─────────────────────────────────────────
def dense_to_intervals(
    dense_array: np.ndarray,
    individuals: list[str],
    *,
    sample_rate: float | None = None,
    time_coord: np.ndarray | None = None,
) -> pd.DataFrame:
    """Convert a dense label array to an intervals DataFrame.

    Exactly one timing source is needed: either *sample_rate* (uniform
    spacing starting at t = 0) or an explicit *time_coord* array.

    Parameters
    ----------
    dense_array : np.ndarray
        Shape ``(n_samples,)`` for a single individual, or
        ``(n_samples, n_individuals)`` for multiple.
    individuals : list[str]
        Individual identifiers — length must match the second axis.
    sample_rate : float, optional
        Sampling rate in Hz.  Timestamps are computed as
        ``np.arange(n_samples) / sample_rate``.
    time_coord : np.ndarray, optional
        Explicit time array of length ``n_samples``.  Use this when
        timestamps are non-uniform or do not start at zero.

    Returns
    -------
    pd.DataFrame
        Intervals with columns ``onset_s``, ``offset_s``, ``labels``,
        ``individual``.  ``offset_s`` is **inclusive** (last sample of the
        segment).

    Raises
    ------
    ValueError
        If neither *sample_rate* nor *time_coord* is given, or if the
        number of individuals does not match the array width.

    Examples
    --------
    Convert a 1-D dense array at 10 Hz:

    >>> import numpy as np
    >>> from ethograph.labels.ml import dense_to_intervals
    >>> labels = np.array([0, 1, 1, 1, 0, 2, 2])
    >>> df = dense_to_intervals(labels, ["crow_A"], sample_rate=10.0)
    >>> df[["onset_s", "offset_s", "labels"]].values.tolist()
    [[0.1, 0.3, 1], [0.5, 0.6, 2]]

    With explicit timestamps:

    >>> times = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
    >>> df = dense_to_intervals(labels, ["crow_A"], time_coord=times)
    >>> df["onset_s"].tolist()
    [0.1, 0.5]
    """
    dense_array = np.asarray(dense_array)
    if sample_rate is None and time_coord is None:
        raise ValueError("Provide either sample_rate or time_coord")

    # Derive timestamps from the sample rate unless given explicitly.
    if time_coord is None:
        times = np.arange(dense_array.shape[0]) / sample_rate
    else:
        times = np.asarray(time_coord)

    # Promote a 1-D array to a single-column 2-D view.
    if dense_array.ndim == 1:
        dense_array = dense_array.reshape(-1, 1)
    if dense_array.shape[1] != len(individuals):
        raise ValueError(
            f"dense_array has {dense_array.shape[1]} columns but "
            f"{len(individuals)} individuals given"
        )

    rows: list[dict] = []
    for idx, name in enumerate(individuals):
        rows.extend(
            _get_labels_start_end_times(dense_array[:, idx], times, str(name))
        )
    return _rows_to_df(rows)
def intervals_to_dense(
    df: pd.DataFrame,
    sample_rate: float,
    individuals: list[str],
    n_samples: int,
) -> np.ndarray:
    """Convert an intervals DataFrame to a dense label array.

    Each interval is mapped onto the nearest sample indices using
    ``round(time * sample_rate)``.  Overlapping intervals for the same
    individual are resolved by last-write-wins.

    Parameters
    ----------
    df : pd.DataFrame
        Intervals DataFrame with columns ``onset_s``, ``offset_s``,
        ``labels``, ``individual``.
    sample_rate : float
        Sampling rate in Hz (e.g. 30.0 for 30 fps video features).
    individuals : list[str]
        Individual identifiers.  The output column order matches this list.
    n_samples : int
        Number of output time steps.  Typically available as per-trial
        ``n_samples`` metadata in the TSV file.

    Returns
    -------
    np.ndarray
        Dense label array of shape ``(n_samples, len(individuals))``,
        dtype ``int8``.  Background (unlabeled) time steps are 0.

    Examples
    --------
    >>> import pandas as pd
    >>> from ethograph.labels.ml import intervals_to_dense
    >>> df = pd.DataFrame({
    ...     "onset_s": [0.1, 0.5], "offset_s": [0.3, 0.6],
    ...     "labels": [1, 2], "individual": ["A", "A"],
    ... })
    >>> dense = intervals_to_dense(df, sample_rate=10.0, individuals=["A"], n_samples=7)
    >>> dense[:, 0].tolist()
    [0, 1, 1, 1, 0, 2, 2]
    """
    out = np.zeros((n_samples, len(individuals)), dtype=np.int8)
    column_of = {name: i for i, name in enumerate(individuals)}

    # Point events are dropped; only state intervals are rasterized.
    for _, interval in states_only(df).iterrows():
        col = column_of.get(interval["individual"])
        if col is None:
            # Individual not requested in the output — skip.
            continue
        # Snap interval bounds to the nearest samples, clamped to the array.
        first = max(0, int(round(interval["onset_s"] * sample_rate)))
        last = min(n_samples - 1, int(round(interval["offset_s"] * sample_rate)))
        # offset_s is inclusive, hence last + 1 in the slice.
        out[first : last + 1, col] = int(interval["labels"])
    return out
# ── Segment index extraction ────────────────────────────────────────────
def get_labels_start_end_indices(col, bg_class=0):
    """Return segment boundaries as sample indices (exclusive end).

    Useful for slicing dense arrays or computing segment-level metrics.

    Parameters
    ----------
    col : array-like
        1-D dense label array.
    bg_class : int
        Background class to ignore (default 0).

    Returns
    -------
    labels : list[int]
        Label class for each segment.
    starts : list[int]
        Start index (inclusive) of each segment.
    ends : list[int]
        End index (**exclusive**) — use ``array[start:end]`` to slice.

    Examples
    --------
    >>> from ethograph.labels.ml import get_labels_start_end_indices
    >>> labels, starts, ends = get_labels_start_end_indices([0,1,1,1,0,2,2])
    >>> labels
    [1, 2]
    >>> starts
    [1, 5]
    >>> ends
    [4, 7]

    >>> # To extract the first segment from a feature array:
    >>> # segment_features = features[starts[0]:ends[0], :]
    """
    # Unzip the (label, start, end) tuples into three parallel lists.
    segs = _get_segments(col, bg_class)
    return (
        [label for label, _, _ in segs],
        [start for _, start, _ in segs],
        [end for _, _, end in segs],
    )
# ── Dense post-processing ────────────────────────────────────────────────
def stitch_gaps(
    labels: np.ndarray,
    max_gap_len: int,
    skip_labels: set[int] | None = None,
) -> np.ndarray:
    """Fill small background gaps between same-label segments.

    Scans the dense label array for short runs of zeros (background)
    flanked by the same non-zero label on both sides.  When the gap is at
    most *max_gap_len* samples, it is filled with that label.

    This is typically used **after** model prediction to clean up
    fragmented outputs where a behaviour is briefly interrupted by a few
    background frames.

    Parameters
    ----------
    labels : np.ndarray
        1-D dense label array (int), where 0 = background.
    max_gap_len : int
        Maximum gap length **in samples** to fill.  Gaps longer than this
        are left untouched.  Convert from seconds:
        ``int(gap_s * sample_rate)``.
    skip_labels : set[int], optional
        Labels whose trailing gaps should **never** be filled.  For
        example, ``skip_labels={3}`` means that a gap preceded by label 3
        is always kept even if the same label follows.

    Returns
    -------
    np.ndarray
        Copy of *labels* with qualifying gaps filled.

    Examples
    --------
    Basic gap stitching at 30 Hz (fill gaps up to 5 frames):

    >>> import numpy as np
    >>> from ethograph.labels.ml import stitch_gaps
    >>> pred = np.array([1, 1, 0, 1, 1, 0, 0, 0, 2, 2])
    >>> stitch_gaps(pred, max_gap_len=2).tolist()
    [1, 1, 1, 1, 1, 0, 0, 0, 2, 2]

    The single-frame gap (index 2) is filled because label 1 appears on
    both sides.  The 3-frame gap (indices 5–7) is left alone because it
    exceeds ``max_gap_len=2``.

    Using ``skip_labels`` to protect specific transitions:

    >>> pred = np.array([3, 3, 0, 3, 3])
    >>> stitch_gaps(pred, max_gap_len=2, skip_labels={3}).tolist()
    [3, 3, 0, 3, 3]
    """
    if skip_labels is None:
        skip_labels = set()
    stitched = labels.copy()

    # Locate runs of background samples (equivalent to
    # find_blocks(labels == 0)); `ends` are inclusive indices.
    padded = np.concatenate(([0], (labels == 0).astype(int), [0]))
    diff = np.diff(padded)
    zero_starts = np.where(diff == 1)[0]
    zero_ends = np.where(diff == -1)[0] - 1

    for start, end in zip(zero_starts, zero_ends):
        # Inclusive bounds: the run start..end spans end - start + 1 samples.
        # (The previous `end - start` computation under-counted by one, so
        # gaps of max_gap_len + 1 samples were also filled.)
        gap_len = end - start + 1
        if gap_len > max_gap_len:
            continue
        # Labels flanking the gap; treat array edges as background.
        left_label = labels[start - 1] if start > 0 else 0
        right_label = labels[end + 1] if end < len(labels) - 1 else 0
        if left_label in skip_labels:
            continue
        # Fill only gaps bounded by the same non-background label.
        if left_label != 0 and left_label == right_label:
            stitched[start:end + 1] = left_label
    return stitched
def purge_small_blocks(
    labels: np.ndarray,
    min_length: int,
    label_thresholds: dict[int, int] | None = None,
) -> np.ndarray:
    """Remove label blocks shorter than a threshold (set to background).

    Scans the dense label array for contiguous runs of the same non-zero
    label.  If a run is shorter than its threshold, every sample in it is
    set to 0 (background).

    This is the dense-array counterpart of
    :func:`~ethograph.labels.intervals.purge_short_intervals` (which works
    in seconds on interval DataFrames).

    Parameters
    ----------
    labels : np.ndarray
        1-D dense label array (int), where 0 = background.
    min_length : int
        Default minimum block length **in samples**.  Blocks shorter than
        this are zeroed out.  Convert from seconds:
        ``int(min_duration_s * sample_rate)``.
    label_thresholds : dict[int, int], optional
        Per-label minimum lengths that override *min_length*.  For
        example, ``{1: 10, 3: 30}`` means label 1 needs ≥10 samples and
        label 3 needs ≥30 samples; all other labels use *min_length*.

    Returns
    -------
    np.ndarray
        Copy of *labels* with short blocks zeroed out.

    Examples
    --------
    Remove any block shorter than 3 samples:

    >>> import numpy as np
    >>> from ethograph.labels.ml import purge_small_blocks
    >>> pred = np.array([0, 1, 0, 2, 2, 2, 2, 0])
    >>> purge_small_blocks(pred, min_length=3).tolist()
    [0, 0, 0, 2, 2, 2, 2, 0]

    With per-label thresholds (label 2 needs ≥5 samples):

    >>> purge_small_blocks(pred, min_length=1, label_thresholds={2: 5}).tolist()
    [0, 1, 0, 0, 0, 0, 0, 0]

    Typical pipeline — purge then stitch:

    >>> pred = np.array([1,1,1, 0, 1, 0, 1,1,1])
    >>> cleaned = purge_small_blocks(pred, min_length=2)  # remove isolated 1-sample
    >>> cleaned.tolist()
    [1, 1, 1, 0, 0, 0, 1, 1, 1]
    >>> stitch_gaps(cleaned, max_gap_len=4).tolist()
    [1, 1, 1, 1, 1, 1, 1, 1, 1]
    """
    labels = np.asarray(labels)
    if len(labels) == 0:
        return labels.copy()

    # Normalize threshold keys to plain ints so lookups with numpy ints match.
    if label_thresholds is None:
        thresholds = {}
    else:
        thresholds = {int(k): v for k, v in label_thresholds.items()}

    cleaned = labels.copy()
    # Sentinels guarantee boundaries before the first and after the last run.
    sentinel = np.concatenate([[-1], labels, [-1]])
    boundaries = np.nonzero(sentinel[:-1] != sentinel[1:])[0]

    # Consecutive boundary pairs delimit one constant-value run each
    # (exclusive end).
    for run_start, run_end in zip(boundaries[:-1], boundaries[1:]):
        if run_start >= len(labels):
            # Defensive: boundary from the trailing sentinel.
            continue
        value = int(labels[run_start])
        if value == 0:
            # Background runs are never purged.
            continue
        if run_end - run_start < thresholds.get(value, min_length):
            cleaned[run_start:run_end] = 0
    return cleaned
def fix_endings(labels: np.ndarray, changepoints) -> np.ndarray:
    """Extend label endings by one sample at changepoint boundaries.

    When a labelled segment ends and the very next sample is a
    changepoint, the segment is extended by one sample.  This accounts
    for the common off-by-one between predicted segment boundaries and
    detected changepoints.

    Parameters
    ----------
    labels : np.ndarray
        1-D dense label array (int).
    changepoints : array-like
        Either a boolean mask of the same length (True = changepoint),
        or an array of integer changepoint indices.

    Returns
    -------
    np.ndarray
        Copy of *labels* with qualifying endings extended by one sample.

    Examples
    --------
    >>> import numpy as np
    >>> from ethograph.labels.ml import fix_endings
    >>> labels = np.array([0, 1, 1, 0, 0, 2, 2, 0])
    >>> cps = np.array([0, 0, 0, 1, 0, 0, 0, 1], dtype=bool)
    >>> fix_endings(labels, cps).tolist()
    [0, 1, 1, 1, 0, 2, 2, 2]

    The segment of label 1 ended at index 2, and index 3 is a
    changepoint, so label 1 is extended to index 3.  Same for label 2 at
    index 7.
    """
    out = np.array(labels).reshape(-1)

    cps = np.array(changepoints)
    # A boolean array — or an int array containing only 0/1 — is treated as
    # a per-sample mask; anything else as a list of changepoint indices.
    if cps.dtype == bool or (
        cps.dtype == int and set(np.unique(cps)).issubset({0, 1})
    ):
        cp_indices = set(np.where(cps)[0])
    else:
        cp_indices = set(changepoints)

    # Segment ends: non-zero samples whose successor is background.
    # (The final sample never qualifies — there is nothing to extend into.)
    nonzero_here = out != 0
    zero_after = np.concatenate([out[1:] == 0, [False]])
    for seg_end in np.where(nonzero_here & zero_after)[0]:
        nxt = seg_end + 1
        if nxt in cp_indices and out[seg_end] != 0 and out[nxt] == 0:
            out[nxt] = out[seg_end]
    return out