PredictionsStore

class ethograph.labels.predictions.PredictionsStore(folder)

Bases: object

Lazy per-trial loader for a predictions folder.

Scans the folder at construction time. The scan is fast because it only lists the filesystem and reads no files. Individual trial data is loaded on demand via get_confidence().

Supports .npy (memory-mapped when shape is 1-D) and .pkl/.pickle formats. Additional formats can be added to load_prediction_file.

Parameters:

folder (str or Path) – Folder containing per-trial prediction files.

Example

store = PredictionsStore("predictions_cetnet_20260330/uncorr")
confidence = store.get_confidence(trial=5, dt=dt)
labels_df, levels = store.load_all(dt, individual="Poppy", confidence_threshold=0.75)
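The scan-then-load-on-demand pattern described for this class can be illustrated with a small stand-in. This is a sketch only, not the PredictionsStore implementation: the class name LazyFolderIndex, the read_bytes method, and the choice to key files by their stem are all assumptions made for illustration.

```python
import tempfile
from pathlib import Path


class LazyFolderIndex:
    """Hypothetical stand-in: index files up front, read them on demand."""

    SUFFIXES = {".npy", ".pkl", ".pickle"}

    def __init__(self, folder):
        # Directory listing only; no file is opened here.
        # Keying by file stem is an assumption, not the library's rule.
        self.files = {
            p.stem: p
            for p in sorted(Path(folder).iterdir())
            if p.suffix.lower() in self.SUFFIXES
        }

    def read_bytes(self, key):
        # The file is opened only when a caller asks for it.
        path = self.files.get(key)
        return None if path is None else path.read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "trial_5.npy").write_bytes(b"abc")
    (Path(tmp) / "notes.txt").write_bytes(b"ignored")
    index = LazyFolderIndex(tmp)
    indexed_keys = list(index.files)
    payload = index.read_bytes("trial_5")
    missing = index.read_bytes("trial_6")
```

Construction cost here depends on the number of directory entries rather than on file sizes, which is the property the class description relies on.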

Methods

get_confidence(trial, dt)

Load and return the confidence array for one trial.

get_file(trial, trial_list)

Return the prediction file path for a trial, or None.

load_all(dt, individual[, ...])

Load all trials — convert to intervals and compute confidence levels.

get_confidence(trial, dt)

Load and return the confidence array for one trial.

For .npy probability files the array is memory-mapped; for .pkl files the full file is read (typically ~150 KB, which takes a few milliseconds). The returned array is not cached; call get_confidence() again to reload it.

Return type:

np.ndarray | None
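The difference between the two loading paths can be sketched with NumPy and the standard library. The function load_confidence_sketch below is hypothetical and is not the library's load_prediction_file; it only assumes that dispatch happens on the file extension, and it uses np.load with mmap_mode="r" for the memory-mapped case.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_confidence_sketch(path):
    """Hypothetical loader: dispatch on file extension."""
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".npy":
        # Read-only memory map: data is paged in from disk on access.
        return np.load(path, mmap_mode="r")
    if suffix in {".pkl", ".pickle"}:
        # Pickle files are read in full. Only unpickle files you trust.
        with path.open("rb") as fh:
            return pickle.load(fh)
    return None  # unsupported format


with tempfile.TemporaryDirectory() as tmp:
    npy_path = Path(tmp) / "trial.npy"
    pkl_path = Path(tmp) / "trial.pkl"
    np.save(npy_path, np.array([0.9, 0.8, 0.7]))
    with pkl_path.open("wb") as fh:
        pickle.dump([0.9, 0.8, 0.7], fh)

    mapped = load_confidence_sketch(npy_path)
    is_memmap = isinstance(mapped, np.memmap)
    npy_values = mapped.tolist()
    del mapped  # release the mapping before the temp folder is removed
    pkl_values = load_confidence_sketch(pkl_path)
    unsupported = load_confidence_sketch(Path(tmp) / "trial.csv")
```

Because nothing is cached in this sketch either, each call reopens the file, mirroring the behaviour described above.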

get_file(trial, trial_list)

Return the prediction file path for a trial, or None.

Return type:

Path | None

load_all(dt, individual, confidence_threshold=0.75, segment_confidence_threshold=0.6)

Load all trials — convert to intervals and compute confidence levels.

Confidence arrays are computed in one pass then discarded; only the per-trial high/low classification is kept. The same two-condition criterion used in the confidence PDF is applied: a trial is “low” if its overall mean confidence < confidence_threshold OR any labeled segment’s mean confidence < segment_confidence_threshold.

Parameters:
  • dt (TrialTree)

  • individual (str)

  • confidence_threshold (float) – Frame-level threshold; overall trial mean below this → “low”.

  • segment_confidence_threshold (float) – Segment-level threshold; any segment mean below this → “low”.

Return type:

tuple[DataFrame, dict[int | str, str]]

Returns:

  • all_labels_df (pd.DataFrame)

  • confidence_levels (dict) – {trial: "low" | "high"}
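The two-condition criterion described for load_all can be written out as a small function. This is a minimal sketch under stated assumptions, not the library's code: the function name classify_trial and the representation of labeled segments as (start, stop) frame-index pairs with an exclusive stop are hypothetical. Only the two thresholds and their defaults come from the signature above.

```python
from statistics import fmean


def classify_trial(confidence, segments,
                   confidence_threshold=0.75,
                   segment_confidence_threshold=0.6):
    """Return "low" or "high" for one trial (illustrative only).

    confidence: per-frame confidence values (any sequence of floats).
    segments: hypothetical list of (start, stop) frame indices, stop exclusive.
    """
    # Condition 1: overall trial mean below the frame-level threshold.
    if fmean(confidence) < confidence_threshold:
        return "low"
    # Condition 2: any labeled segment with a mean below the segment threshold.
    for start, stop in segments:
        if fmean(confidence[start:stop]) < segment_confidence_threshold:
            return "low"
    return "high"


steady = [0.9] * 10
dipped = [0.95] * 8 + [0.5, 0.5]  # overall mean 0.86, last segment mean 0.5
levels = {
    1: classify_trial(steady, [(0, 5), (5, 10)]),
    2: classify_trial(dipped, [(0, 8), (8, 10)]),
    3: classify_trial([0.7] * 10, [(0, 10)]),
}
```

Trial 2 shows why the second condition matters: its overall mean clears 0.75, yet one short segment averages 0.5, so the trial is still classified as "low". The sketch assumes every segment contains at least one frame, since fmean raises on empty input.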