TrialTree#
TrialTree is a wrapper around xarray.DataTree that stores
one xarray.Dataset per trial. Build one from a list of datasets,
then access each trial by ID or index:
import numpy as np, xarray as xr, ethograph as eto

# Build: one xr.Dataset per trial
datasets = []
for i in range(1, 4):
    ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(300), dims=["time"],
                                           coords={"time": np.arange(300) / 30.0})})
    ds.attrs["trial"] = i
    ds.attrs["fps"] = 30.0
    datasets.append(ds)

dt = eto.from_datasets(datasets)
# Access by trial ID (label-based, like xr.Dataset.sel)
ds = dt.trial(2)
ds.attrs["trial"] # 2
ds["speed"] # the speed DataArray for trial 2
# Access by integer index (0-based, like xr.Dataset.isel)
ds = dt.itrial(0)
ds.attrs["trial"] # 1
# List all trial IDs
dt.trials # [1, 2, 3]
# Save / load
dt.save("session.nc")
dt = eto.open("session.nc")
For the xarray.Dataset structure expected inside each trial, see
Data Format Requirements.
Creating#
- classmethod TrialTree.open(path)[source]#
Load a TrialTree from a NetCDF file.
Auto-discovers .ethograph/alignment.nwb next to the file.
- Return type:
TrialTree
- classmethod TrialTree.from_datasets(datasets, validate=True)[source]#
Build a TrialTree from a list of xarray Datasets.
- classmethod TrialTree.from_continuous(ds, epochs)[source]#
Build a TrialTree from a single continuous recording + trial epochs.
Unlike from_datasets(), which requires pre-split data, this stores one shared dataset and slices on demand when trial() is called. Time coordinates are shifted to 0 per trial.
- Parameters:
ds (xr.Dataset) – Full recording dataset. Must have at least one dimension whose name contains "time".
epochs (pd.DataFrame | nap.IntervalSet) – Trial boundaries. Accepts:
  - pd.DataFrame with columns trial, start_time, stop_time.
  - nap.IntervalSet — trial IDs taken from a "trial" metadata column, or 1, 2, … if absent.
- Return type:
TrialTree
Examples
>>> import pandas as pd
>>> epochs = pd.DataFrame({
...     "trial": [1, 2, 3],
...     "start_time": [0.0, 60.0, 120.0],
...     "stop_time": [60.0, 120.0, 180.0],
... })
>>> dt = TrialTree.from_continuous(ds, epochs)
>>> dt.trial(2)  # returns 60-120s slice, time shifted to 0
For a single long recording with trial epochs, from_continuous()
slices on demand instead of copying data:
import pandas as pd
epochs = pd.DataFrame({
"trial": [1, 2, 3],
"start_time": [0.0, 60.0, 120.0],
"stop_time": [60.0, 120.0, 180.0],
})
dt = eto.from_continuous(ds, epochs)
dt.trial(2) # returns 60–120 s slice, time shifted to 0
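The on-demand slicing that from_continuous() performs per trial can be sketched with plain xarray and pandas. This is an illustrative sketch of the behaviour described above, not the library code; the helper name slice_trial is hypothetical.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Continuous 180 s recording sampled at 30 fps
ds = xr.Dataset(
    {"speed": xr.DataArray(np.random.rand(5400), dims=["time"])},
    coords={"time": np.arange(5400) / 30.0},
)
epochs = pd.DataFrame({
    "trial": [1, 2, 3],
    "start_time": [0.0, 60.0, 120.0],
    "stop_time": [60.0, 120.0, 180.0],
})

def slice_trial(ds, epochs, trial_id):
    """Select one trial's window and shift its time coordinate to start at 0."""
    row = epochs.set_index("trial").loc[trial_id]
    out = ds.sel(time=slice(row["start_time"], row["stop_time"]))
    return out.assign_coords(time=out["time"] - out["time"][0])

trial2 = slice_trial(ds, epochs, 2)
print(float(trial2["time"][0]))  # 0.0
```

Because only a view of the shared dataset is selected, no trial data is copied until the caller materialises it.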
Accessing trials#
- TrialTree.trial(trial)[source]#
Return the dataset for the given trial ID.
Examples
>>> import ethograph as eto
>>> dt = eto.open("session.nc")
>>> ds = dt.trial(1)
>>> ds.attrs["trial"]
1
>>> ds["speed"]  # access a feature variable
<xarray.DataArray 'speed' (time: 9000, keypoints: 4)>
- TrialTree.itrial(trial_idx)[source]#
Return the dataset at an integer index (0-based).
Examples
>>> import ethograph as eto
>>> dt = eto.open("session.nc")
>>> dt.trials
[1, 2, 3]
>>> ds = dt.itrial(0)  # same as dt.trial(1)
>>> ds.attrs["trial"]
1
>>> ds = dt.itrial(2)  # same as dt.trial(3)
>>> ds.attrs["trial"]
3
Iterating over trials#
for trial_id, ds in dt.trial_items():
print(f"Trial {trial_id}: {len(ds.time)} timepoints")
# Apply a function to every trial, returning a new TrialTree
dt_smoothed = dt.map_trials(lambda ds: smooth(ds))
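The smooth above is a placeholder. A concrete per-trial function suitable for map_trials() might look like the following, written in plain xarray so it runs on any dataset with a speed variable and a time dimension (the window size and min_periods choice are illustrative, not library defaults):

```python
import numpy as np
import xarray as xr

def smooth(ds, window=5):
    """Return a copy with speed replaced by a centered rolling mean."""
    return ds.assign(
        speed=ds["speed"].rolling(time=window, center=True, min_periods=1).mean()
    )

ds = xr.Dataset(
    {"speed": xr.DataArray(np.ones(300), dims=["time"])},
    coords={"time": np.arange(300) / 30.0},
)
smoothed = smooth(ds)
# Smoothing a constant signal leaves it unchanged
print(float(smoothed["speed"].mean()))  # 1.0
```

min_periods=1 keeps the edges of each trial defined instead of introducing NaNs, which matters when trials are short.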
Modifying trials#
In-place mutations work directly through trial() because
the returned dataset shares its underlying data with the tree:
dt.trial(1).attrs["human_verified"] = True
dt.trial(1)["speed"].values[:10] = 0.0
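This shared-data behaviour mirrors how xarray itself works: selecting a variable returns a DataArray that wraps the same numpy buffer as the parent Dataset, so writes through .values are visible everywhere. A minimal demonstration with plain xarray:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(100), dims=["time"])})

# ds["speed"] wraps the same numpy buffer as the Dataset,
# so in-place writes propagate to the parent.
ds["speed"].values[:10] = 0.0
print(ds["speed"].values[:10].sum())  # 0.0
```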
Structural changes (adding/removing variables) require
update_trial():
dt.update_trial(1, lambda ds: ds.assign(
smoothed_speed=ds["speed"].rolling(time=5).mean()
))
Filtering#
dt_tone_a = dt.filter_by_attr("stimulus", "tone_A")
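filter_by_attr keeps the trials whose attrs match the given key/value. The selection logic amounts to something like this plain-xarray sketch over a list of per-trial datasets (the standalone filter_by_attr function here is illustrative, not the library implementation):

```python
import numpy as np
import xarray as xr

def filter_by_attr(datasets, key, value):
    """Keep only datasets whose attrs[key] equals value."""
    return [ds for ds in datasets if ds.attrs.get(key) == value]

datasets = []
for i, stim in enumerate(["tone_A", "tone_B", "tone_A"], start=1):
    ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(30), dims=["time"])})
    ds.attrs.update(trial=i, stimulus=stim)
    datasets.append(ds)

tone_a = filter_by_attr(datasets, "stimulus", "tone_A")
print([ds.attrs["trial"] for ds in tone_a])  # [1, 3]
```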
Saving#
dt.save("path/to/session.nc")
dt.save() # overwrite the file it was loaded from
When saving to a new directory, the NWB alignment file is automatically
copied alongside the .nc.
- TrialTree.save(path=None)[source]#
Write the TrialTree to a NetCDF file.
Continuous trees are materialised into per-trial nodes before saving so the file can be re-opened with open().
Uses an atomic write (temp file then rename) to avoid partial writes. If an alignment NWB exists and the save directory differs from where the NWB lives, a copy is placed in <save_dir>/.ethograph/ so that open() can discover it.
- Return type:
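The atomic-write pattern described above (write to a temporary file, then rename into place) can be sketched generically with the standard library; os.replace is atomic when source and destination are on the same filesystem. This is a generic illustration of the technique, not ethograph's save() code, and atomic_write_bytes is a hypothetical helper name:

```python
import os
import pathlib
import tempfile

def atomic_write_bytes(path, data):
    """Write data via temp file + rename, so readers never see a partial file."""
    dirname = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp, path)  # atomic on the same filesystem
    except BaseException:
        if os.path.exists(tmp):
            os.remove(tmp)
        raise

target = pathlib.Path(tempfile.mkdtemp()) / "session.nc"
atomic_write_bytes(target, b"netcdf-bytes")
print(target.read_bytes())  # b'netcdf-bytes'
```

Creating the temp file in the destination directory (rather than the system temp dir) is what guarantees the final rename never crosses a filesystem boundary.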
See also#
- Dataset — dataset builders (downsample_trialtree(), add_changepoints_to_ds(), add_angle_rgb_to_ds()).
- Pynapple IO — pynapple loading (load_nap_data(), add_changepoints_to_nap()) and NWB-import probes.
- NWB alignment — trial timing, media paths, stream offsets via dt.nwb_alignment.
- Labels — TSV label sidecar format and helpers.
- Changepoints — detection, merging, time extraction, and label correction.