TrialTree

TrialTree is a wrapper around xarray.DataTree that stores one xarray.Dataset per trial. Build one from a list of datasets, then access each trial by ID or index:

import numpy as np
import xarray as xr
import ethograph as eto

# Build: one xr.Dataset per trial
datasets = []
for i in range(1, 4):
    ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(300), dims=["time"],
                                            coords={"time": np.arange(300) / 30.0})})
    ds.attrs["trial"] = i
    ds.attrs["fps"] = 30.0
    datasets.append(ds)

dt = eto.from_datasets(datasets)

# Access by trial ID (label-based, like xr.Dataset.sel)
ds = dt.trial(2)
ds.attrs["trial"]   # 2
ds["speed"]          # the speed DataArray for trial 2

# Access by integer index (0-based, like xr.Dataset.isel)
ds = dt.itrial(0)
ds.attrs["trial"]   # 1

# List all trial IDs
dt.trials            # [1, 2, 3]

# Save / load
dt.save("session.nc")
dt = eto.open("session.nc")

For the xarray.Dataset structure expected inside each trial, see Data Format Requirements.

class ethograph.io.trialtree.TrialTree(*args: Any, **kwargs: Any)

Creating

classmethod TrialTree.open(path)

Load a TrialTree from a NetCDF file.

Automatically picks up an alignment file at .ethograph/alignment.nwb in the same directory as the NetCDF file, if one exists.

Return type:

TrialTree
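
The auto-discovery rule above amounts to a fixed relative path check. A minimal sketch, assuming only the location stated in the docs; find_alignment_nwb is a hypothetical helper, not part of ethograph:

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_alignment_nwb(nc_path) -> Optional[Path]:
    """Hypothetical helper: return <dir>/.ethograph/alignment.nwb if it exists.

    <dir> is the directory that contains the NetCDF file given to open().
    """
    candidate = Path(nc_path).parent / ".ethograph" / "alignment.nwb"
    return candidate if candidate.is_file() else None


# Demo in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    nc = Path(tmp) / "session.nc"
    before = find_alignment_nwb(nc)  # nothing next to the file yet
    nwb = Path(tmp) / ".ethograph" / "alignment.nwb"
    nwb.parent.mkdir()
    nwb.write_bytes(b"")
    after = find_alignment_nwb(nc)   # now the alignment file is found
```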

classmethod TrialTree.from_datasets(datasets, validate=True)

Build a TrialTree from a list of xarray Datasets.

Parameters:
  • datasets (list[Dataset]) – Each dataset must carry an attrs["trial"] value that is unique across the list.

  • validate (bool) – Run validation after construction.

Return type:

TrialTree
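
The uniqueness requirement on attrs["trial"] can be checked in a few lines. This is a sketch of that one rule only, assuming nothing about what else validate=True checks; check_trial_ids is hypothetical, and SimpleNamespace objects stand in for xr.Dataset because only .attrs is read:

```python
from types import SimpleNamespace


def check_trial_ids(datasets):
    """Hypothetical check: every dataset has attrs["trial"] and IDs are unique."""
    seen = set()
    for i, ds in enumerate(datasets):
        if "trial" not in ds.attrs:
            raise ValueError(f"dataset {i} has no attrs['trial']")
        trial_id = ds.attrs["trial"]
        if trial_id in seen:
            raise ValueError(f"duplicate trial ID {trial_id!r}")
        seen.add(trial_id)
    return [ds.attrs["trial"] for ds in datasets]


# Stand-ins for xr.Dataset: only the .attrs mapping matters here
good = [SimpleNamespace(attrs={"trial": i, "fps": 30.0}) for i in (1, 2, 3)]
ids = check_trial_ids(good)

# Reusing trial ID 2 should be rejected
bad = good + [SimpleNamespace(attrs={"trial": 2})]
```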

classmethod TrialTree.from_continuous(ds, epochs)

Build a TrialTree from a single continuous recording and a table of trial epochs.

Unlike from_datasets(), which requires pre-split data, this method stores one shared dataset and slices it on demand when trial() is called. Time coordinates are shifted so that each trial starts at 0.

Parameters:
  • ds (xr.Dataset) – Full recording dataset. Must have at least one dimension whose name contains "time".

  • epochs (pd.DataFrame | nap.IntervalSet) –

    Trial boundaries. Accepts:

    • pd.DataFrame with columns trial, start_time, stop_time.

    • nap.IntervalSet: trial IDs are taken from a "trial" metadata column, or numbered 1, 2, … if that column is absent.

Return type:

TrialTree

Examples

>>> import pandas as pd
>>> from ethograph.io.trialtree import TrialTree
>>> epochs = pd.DataFrame({
...     "trial": [1, 2, 3],
...     "start_time": [0.0, 60.0, 120.0],
...     "stop_time": [60.0, 120.0, 180.0],
... })
>>> dt = TrialTree.from_continuous(ds, epochs)
>>> dt.trial(2)  # returns the 60–120 s slice, time shifted to 0

For a single long recording with trial epochs, from_continuous() slices on demand instead of copying data:

import pandas as pd

# Assumes ds is one continuous xr.Dataset covering at least 0 to 180 s of recording

epochs = pd.DataFrame({
    "trial": [1, 2, 3],
    "start_time": [0.0, 60.0, 120.0],
    "stop_time": [60.0, 120.0, 180.0],
})
dt = eto.from_continuous(ds, epochs)
dt.trial(2)  # returns 60–120 s slice, time shifted to 0
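
The slice-and-shift step can be sketched with NumPy and pandas alone. This assumes half-open [start_time, stop_time) windows so adjacent trials never share a boundary sample; the real method may treat boundaries differently, works on xr.Dataset objects rather than bare arrays, and slice_trial is a hypothetical helper:

```python
import numpy as np
import pandas as pd


def slice_trial(time, values, epochs, trial):
    """Hypothetical helper: cut one trial out of a continuous signal.

    Assumes half-open [start_time, stop_time) windows and shifts the
    returned time axis so the trial starts at 0.
    """
    row = epochs.loc[epochs["trial"] == trial].iloc[0]
    start, stop = row["start_time"], row["stop_time"]
    mask = (time >= start) & (time < stop)
    return time[mask] - start, values[mask]


fps = 30.0
time = np.arange(int(180 * fps)) / fps          # 0 to 180 s at 30 fps
speed = np.random.default_rng(0).random(time.size)

epochs = pd.DataFrame({
    "trial": [1, 2, 3],
    "start_time": [0.0, 60.0, 120.0],
    "stop_time": [60.0, 120.0, 180.0],
})

# Trial 2: samples from 60 s up to (not including) 120 s, re-based to 0
t2, s2 = slice_trial(time, speed, epochs, trial=2)
```

Only the mask and the offset are computed per call, which is the sense in which slicing "on demand" avoids copying the whole recording up front.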

classmethod TrialTree.from_datatree(dt, attrs=None, *, source=None)

Wrap an existing DataTree as a TrialTree.

Parameters:

source (TrialTree | None) – Original TrialTree to copy nwb_alignment and _source_path from.

Return type:

TrialTree


Accessing trials

TrialTree.trial(trial)

Return the dataset for the given trial ID.

Parameters:

trial (int or str) – Trial identifier matching ds.attrs["trial"].

Return type:

Dataset

Examples

>>> import ethograph as eto
>>> dt = eto.open("session.nc")
>>> ds = dt.trial(1)
>>> ds.attrs["trial"]
1
>>> ds["speed"]  # access a feature variable
<xarray.DataArray 'speed' (time: 9000, keypoints: 4)>

TrialTree.itrial(trial_idx)

Return the dataset at an integer index (0-based).

Parameters:

trial_idx (int) – Zero-based index into the list of trials.

Return type:

Dataset

Examples

>>> import ethograph as eto
>>> dt = eto.open("session.nc")
>>> dt.trials
[1, 2, 3]
>>> ds = dt.itrial(0)   # same as dt.trial(1)
>>> ds.attrs["trial"]
1
>>> ds = dt.itrial(2)   # same as dt.trial(3)
>>> ds.attrs["trial"]
3
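
The difference between label-based trial() and position-based itrial() is easiest to see when trial IDs are not 1..n. A minimal stand-in, assuming trials keep the order in which they were added (as the [1, 2, 3] example suggests); TrialLookup is hypothetical and is not the real TrialTree:

```python
from types import SimpleNamespace


class TrialLookup:
    """Hypothetical stand-in for TrialTree's lookup methods (not the real class)."""

    def __init__(self, datasets):
        # Map trial ID -> dataset; dicts preserve insertion order
        self._by_id = {ds.attrs["trial"]: ds for ds in datasets}

    @property
    def trials(self):
        return list(self._by_id)

    def trial(self, trial):
        # Label-based, like xr.Dataset.sel
        return self._by_id[trial]

    def itrial(self, trial_idx):
        # Position-based (0-based), like xr.Dataset.isel
        return self._by_id[self.trials[trial_idx]]


# IDs deliberately not 1..n, so labels and positions differ
tree = TrialLookup([SimpleNamespace(attrs={"trial": t}) for t in (10, 20, 30)])
```

With these IDs, tree.trial(20) and tree.itrial(1) return the same object, while tree.trial(1) would raise a KeyError.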

property TrialTree.trials: list[int | str]

List of trial identifiers.

TrialTree.get_all_trials()

Return a dict mapping trial ID to Dataset for all trials.

Return type:

dict[int, Dataset]

TrialTree.get_common_attrs()

Return attributes that are identical across all trials.

Return type:

dict[str, Any]
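
"Identical across all trials" means a key must be present in every trial with an equal value. A sketch of that rule over plain dicts, assuming values compare with == to a single bool (numbers and strings, not NumPy arrays); common_attrs is hypothetical:

```python
def common_attrs(attrs_per_trial):
    """Hypothetical helper: keep attrs whose key and value match in every trial.

    Assumes attribute values compare with == to a single bool
    (e.g. numbers and strings, not NumPy arrays).
    """
    if not attrs_per_trial:
        return {}
    first, *rest = attrs_per_trial
    return {
        key: value
        for key, value in first.items()
        if all(key in other and other[key] == value for other in rest)
    }


attrs_per_trial = [
    {"trial": 1, "fps": 30.0, "stimulus": "tone_A"},
    {"trial": 2, "fps": 30.0, "stimulus": "tone_B"},
    {"trial": 3, "fps": 30.0, "stimulus": "tone_A"},
]
shared = common_attrs(attrs_per_trial)  # {"fps": 30.0}
```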

TrialTree.get_trial_metadata(trial)

Return condition metadata for a single trial as a dict.

Return type:

dict


Iterating over trials

for trial_id, ds in dt.trial_items():
    print(f"Trial {trial_id}: {len(ds.time)} timepoints")

# Apply a function to every trial, returning a new TrialTree
# (here: a 5-sample rolling mean along time)
dt_smoothed = dt.map_trials(lambda ds: ds.rolling(time=5).mean())

TrialTree.trial_items()

Iterate over (trial_id, dataset) pairs for all trial nodes.

TrialTree.map_trials(func)

Apply func to every trial dataset and return a new TrialTree.

For continuous trees, materialises the sliced datasets into a standard per-node TrialTree so that the results of func can be stored.

Return type:

TrialTree
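
The contract of map_trials() (apply func per trial, collect the results in a new container, leave the input alone) can be sketched over a plain dict. The dict stands in for the tree, and map_trials and demean here are hypothetical illustrations:

```python
from typing import Callable, Dict, Hashable, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def map_trials(trials: Dict[Hashable, T], func: Callable[[T], U]) -> Dict[Hashable, U]:
    """Hypothetical sketch: apply func per trial and collect results in a NEW dict.

    The input mapping is never modified, mirroring how map_trials()
    returns a new TrialTree instead of editing the original.
    """
    return {trial_id: func(data) for trial_id, data in trials.items()}


def demean(values):
    # Subtract the per-trial mean (a stand-in for any per-trial transform)
    mean = sum(values) / len(values)
    return [v - mean for v in values]


trials = {1: [0.0, 1.0, 2.0], 2: [3.0, 4.0]}
centered = map_trials(trials, demean)
```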


Modifying trials

In-place mutations work directly through trial() because the returned dataset shares its underlying data with the tree:

dt.trial(1).attrs["human_verified"] = True
dt.trial(1)["speed"].values[:10] = 0.0
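
The shared-data behaviour can be illustrated without xarray. In this hypothetical analogy a dict stands in for the tree and a shallow copy stands in for the dataset that trial() returns: the copy is a new container, but it holds the same array objects, so writes into an array reach the tree while newly added keys do not.

```python
import numpy as np

# Hypothetical analogy: the dict `tree` plays the TrialTree, NumPy arrays play the stored data
tree = {1: {"speed": np.arange(10.0)}}

# Shallow copy: a new container holding the SAME array objects,
# standing in for the Dataset returned by dt.trial(1)
ds = dict(tree[1])

# Writing into an existing array mutates shared memory, so the tree sees it
ds["speed"][:3] = 0.0

# Adding a key changes only the copy; the tree is untouched
ds["smoothed_speed"] = ds["speed"] * 0.5
```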

Structural changes (adding or removing variables) require update_trial():

dt.update_trial(1, lambda ds: ds.assign(
    smoothed_speed=ds["speed"].rolling(time=5).mean()
))

TrialTree.update_trial(trial, func)

Read-modify-write a single trial’s dataset: func receives the current dataset, and the dataset it returns replaces the stored one.

Not supported for continuous trees; call materialise() first if you need per-trial mutation.

Return type:

None
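
The read-modify-write cycle reduces to fetch, transform, store back. A sketch over a dict of per-trial records, where update_trial is a hypothetical function and plain dicts stand in for datasets (the real method additionally refuses continuous trees, which this sketch does not model):

```python
def update_trial(store, trial, func):
    """Hypothetical read-modify-write over a dict of per-trial records."""
    current = store[trial]        # read (KeyError if the trial is unknown)
    store[trial] = func(current)  # modify, then write the result back


store = {1: {"speed": [1.0, 2.0, 3.0]}}

# Like ds.assign(...): build a new record with an extra variable
update_trial(store, 1, lambda ds: {**ds, "doubled": [2 * v for v in ds["speed"]]})
```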


Filtering

dt_tone_a = dt.filter_by_attr("stimulus", "tone_A")

TrialTree.filter_by_attr(attr_name, attr_value)

Return a new TrialTree containing only trials that match an attribute.

Checks the metadata table first; falls back to ds.attrs.

Return type:

TrialTree
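
The lookup order (metadata table first, then ds.attrs) matters when both sources disagree. A sketch of that precedence using dicts for both sources; filter_trial_ids and the data are hypothetical:

```python
def filter_trial_ids(attrs_by_trial, metadata_by_trial, name, value):
    """Hypothetical sketch of the lookup order described for filter_by_attr().

    For each trial the metadata table wins; ds.attrs is consulted only
    when the metadata has no entry for `name`.
    """
    keep = []
    for trial_id, attrs in attrs_by_trial.items():
        meta = metadata_by_trial.get(trial_id, {})
        actual = meta[name] if name in meta else attrs.get(name)
        if actual == value:
            keep.append(trial_id)
    return keep


attrs_by_trial = {1: {"stimulus": "tone_A"}, 2: {"stimulus": "tone_A"}, 3: {}}
metadata_by_trial = {2: {"stimulus": "tone_B"}, 3: {"stimulus": "tone_A"}}

# Trial 1 matches via attrs, trial 2 is overridden by metadata, trial 3 matches via metadata
kept = filter_trial_ids(attrs_by_trial, metadata_by_trial, "stimulus", "tone_A")
```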


Saving

dt.save("path/to/session.nc")
dt.save()  # overwrite the file it was loaded from

When saving to a new directory, the NWB alignment file is automatically copied into a .ethograph/ folder next to the .nc file.

TrialTree.save(path=None)

Write the TrialTree to a NetCDF file.

Continuous trees are materialised into per-trial nodes before saving so the file can be re-opened with open().

Uses an atomic write (temp file then rename) to avoid partial writes. If an alignment NWB exists and the save directory differs from where the NWB lives, a copy is placed in <save_dir>/.ethograph/ so that open() can discover it.

Return type:

None
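
Both save-time behaviours can be sketched with the standard library: write to a temporary file in the target directory and rename it over the destination with os.replace, then copy the alignment NWB into <save_dir>/.ethograph/ only if it lives elsewhere. This is a sketch under stated assumptions: raw bytes stand in for the NetCDF content (the real method writes an xarray tree), and both helper names are hypothetical.

```python
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional


def atomic_write_bytes(path, payload: bytes) -> None:
    """Write to a temp file in the target directory, then rename it over `path`.

    Readers see either the old file or the complete new one, never a
    partial write (the rename is atomic on a single filesystem).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the temp file if anything failed before the rename
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


def copy_alignment_beside(nwb_path, save_path) -> Optional[Path]:
    """Copy the alignment NWB into <save_dir>/.ethograph/ when it lives elsewhere."""
    if nwb_path is None or not Path(nwb_path).is_file():
        return None
    nwb_path = Path(nwb_path)
    target_dir = Path(save_path).parent / ".ethograph"
    if target_dir.resolve() != nwb_path.parent.resolve():
        target_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(nwb_path, target_dir / nwb_path.name)
    return target_dir / nwb_path.name


# Demo: save into a different directory from where the NWB lives
with tempfile.TemporaryDirectory() as tmp:
    src_nwb = Path(tmp) / "raw" / ".ethograph" / "alignment.nwb"
    src_nwb.parent.mkdir(parents=True)
    src_nwb.write_bytes(b"nwb")
    out = Path(tmp) / "export" / "session.nc"
    atomic_write_bytes(out, b"netcdf bytes")
    copied = copy_alignment_beside(src_nwb, out)
    saved_ok = out.read_bytes() == b"netcdf bytes"
    leftovers = list(out.parent.glob("*.tmp"))
    copied_ok = copied.read_bytes() == b"nwb"
    copied_dir = copied.parent.name
```

Keeping the temporary file in the same directory as the destination is what makes the final rename atomic; a temp file on another filesystem would force a copy instead.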


See also