Merge pull request SainsburyWellcomeCentre#381 from ttngu207/datajoint_pipeline

Update SLEAP data ingestion
jkbhagatio committed Jul 23, 2024
2 parents 88e5ecd + 3411fe8 commit 054db5c
Showing 4 changed files with 62 additions and 33 deletions.
1 change: 1 addition & 0 deletions aeon/dj_pipeline/__init__.py
@@ -45,6 +45,7 @@ def fetch_stream(query, drop_pk=True):
df.rename(columns={"timestamps": "time"}, inplace=True)
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
return df


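The convert_dtypes call added to fetch_stream appears intended to give object-dtype columns a concrete dtype, while the convert_*=False flags keep plain NumPy dtypes rather than pandas nullable extension types. A minimal standalone sketch of that behavior (not part of the diff):

import numpy as np
import pandas as pd

# Columns fetched from DataJoint blobs can come back as dtype "object",
# even when every value is numeric.
df = pd.DataFrame({"threshold": np.array([75.0, 100.0, 150.0], dtype=object)})
print(df.dtypes)  # threshold    object

# convert_dtypes() with the extension-type conversions disabled infers a
# concrete NumPy dtype for object columns without switching to pandas
# nullable dtypes (string, Int64, boolean, Float64).
df = df.convert_dtypes(
    convert_string=False,
    convert_integer=False,
    convert_boolean=False,
    convert_floating=False,
)
print(df.dtypes)  # threshold    float64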
48 changes: 36 additions & 12 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -189,18 +189,42 @@ def make(self, key):
patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name")

for patch_key, patch_name in zip(patch_keys, patch_names):
delivered_pellet_df = fetch_stream(
# pellet delivery and patch threshold data
beam_break_df = fetch_stream(
streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction
)[block_start:block_end]
depletion_state_df = fetch_stream(
streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
)[block_start:block_end]
# remove NaNs from threshold column
depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
# identify & remove invalid indices where the time difference is less than 1 second
invalid_indices = np.where(depletion_state_df.index.to_series().diff().dt.total_seconds() < 1)[0]
depletion_state_df = depletion_state_df.drop(depletion_state_df.index[invalid_indices])

# find pellet times associated with each threshold update
# for each threshold, find the time of the next threshold update,
# find the closest beam break after this update time,
# and use this beam break time as the delivery time for the initial threshold
pellet_ts_threshold_df = depletion_state_df.copy()
pellet_ts_threshold_df["pellet_timestamp"] = pd.NaT
for threshold_idx in range(len(pellet_ts_threshold_df) - 1):
if np.isnan(pellet_ts_threshold_df.threshold.iloc[threshold_idx]):
continue
next_threshold_time = pellet_ts_threshold_df.index[threshold_idx + 1]
post_thresh_pellet_ts = beam_break_df.index[beam_break_df.index > next_threshold_time]
next_beam_break = post_thresh_pellet_ts[np.searchsorted(post_thresh_pellet_ts, next_threshold_time)]
pellet_ts_threshold_df.pellet_timestamp.iloc[threshold_idx] = next_beam_break
# remove NaNs from pellet_timestamp column (last row)
pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])

# wheel encoder data
encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[
block_start:block_end
]
# filter out maintenance period based on logs
pellet_df = filter_out_maintenance_periods(
delivered_pellet_df,
pellet_ts_threshold_df = filter_out_maintenance_periods(
pellet_ts_threshold_df,
maintenance_period,
block_end,
dropna=True,
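
The threshold-to-pellet matching above pairs each threshold with the first beam break that follows the next threshold update; a standalone sketch with toy timestamps (not part of the diff):

import numpy as np
import pandas as pd

# Toy threshold updates and beam-break (pellet detection) times.
threshold_times = pd.to_datetime(["12:00:00", "12:05:00", "12:10:00"])
thresholds = pd.DataFrame({"threshold": [75.0, 100.0, 150.0]}, index=threshold_times)
beam_breaks = pd.to_datetime(["12:05:02", "12:10:03", "12:14:30"])

thresholds["pellet_timestamp"] = pd.NaT
for i in range(len(thresholds) - 1):
    next_update = thresholds.index[i + 1]
    later_breaks = beam_breaks[beam_breaks > next_update]
    if len(later_breaks):
        # the first beam break after the next threshold update is taken as
        # the delivery time for the threshold set at row i
        thresholds.iloc[i, thresholds.columns.get_loc("pellet_timestamp")] = later_breaks[0]

# The last row has no following update, so it is dropped (as in the pipeline code).
thresholds = thresholds.dropna(subset=["pellet_timestamp"])
print(thresholds)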
@@ -229,22 +253,21 @@ def make(self, key):

patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]

# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate

self.Patch.insert1(
{
**key,
"patch_name": patch_name,
"pellet_count": len(pellet_df),
"pellet_timestamps": pellet_df.index.values,
"pellet_count": len(pellet_ts_threshold_df),
"pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values,
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[
::wheel_downsampling_factor
],
"wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor],
"patch_threshold": depletion_state_df.threshold.values,
"patch_threshold_timestamps": depletion_state_df.index.values,
"patch_threshold": pellet_ts_threshold_df.threshold.values,
"patch_threshold_timestamps": pellet_ts_threshold_df.index.values,
"patch_rate": patch_rate,
"patch_offset": patch_offset,
}
@@ -267,7 +290,7 @@ def make(self, key):
subject_names = []
for subject_name in set(subject_visits_df.id):
_df = subject_visits_df[subject_visits_df.id == subject_name]
if _df.type[-1] != "Exit":
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

for subject_name in subject_names:
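
The switch from _df.type[-1] to _df.type.iloc[-1] avoids relying on the deprecated integer-as-position fallback when the Series has a timestamp index; a small illustration with assumed data (not part of the diff):

import pandas as pd

visits = pd.Series(
    ["Enter", "Exit", "Enter"],
    index=pd.to_datetime(["12:00", "12:30", "13:00"]),
    name="type",
)
print(visits.iloc[-1])  # "Enter" -- unambiguous positional access to the last event
# visits[-1] relies on the integer-as-position fallback for a non-integer
# index, which emits a FutureWarning in recent pandas and is slated for removal.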
@@ -454,7 +477,7 @@ def make(self, key):
"dist_to_patch"
].values

# Get closest subject to patch at each pel del timestep
# Get closest subject to patch at each pellet timestep
closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1)
# Get closest subject to patch at each wheel timestep
cum_wheel_dist_subj_df = pd.DataFrame(
@@ -481,9 +504,10 @@ def make(self, key):
all_subj_patch_pref_dict[patch["patch_name"]][subject_name][
"cum_time"
] = subject_in_patch_cum_time
subj_pellets = closest_subjects_pellet_ts[closest_subjects_pellet_ts == subject_name]

subj_patch_thresh = patch["patch_threshold"][np.searchsorted(patch["patch_threshold_timestamps"], subj_pellets.index.values) - 1]
closest_subj_mask = closest_subjects_pellet_ts == subject_name
subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
subj_patch_thresh = patch["patch_threshold"][closest_subj_mask]

self.Patch.insert1(
key
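A toy illustration of the per-subject attribution in this hunk: idxmin(axis=1) picks the closest subject at each pellet time, and the resulting boolean mask selects that subject's pellets and the matching patch thresholds (illustrative data, not part of the diff):

import numpy as np
import pandas as pd

pellet_times = pd.to_datetime(["12:01", "12:06", "12:11"])
# Distance of each subject to the patch at each pellet time.
dist_to_patch = pd.DataFrame(
    {"mouse_a": [2.0, 9.0, 1.5], "mouse_b": [8.0, 3.0, 7.0]}, index=pellet_times
)
patch_threshold = np.array([75.0, 100.0, 150.0])  # aligned with pellet_times

closest_subjects = dist_to_patch.idxmin(axis=1)   # mouse_a, mouse_b, mouse_a

for subject in ["mouse_a", "mouse_b"]:
    mask = (closest_subjects == subject).values
    subj_pellets = closest_subjects[mask]
    subj_thresholds = patch_threshold[mask]       # thresholds for this subject's pellets
    print(subject, len(subj_pellets), subj_thresholds)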
10 changes: 6 additions & 4 deletions aeon/dj_pipeline/tracking.py
@@ -174,10 +174,8 @@ def make(self, key):
if not len(pose_data):
raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")

# get bodyparts and classes
bodyparts = stream_reader.get_bodyparts()
anchor_part = bodyparts[0] # anchor_part is always the first one
class_names = stream_reader.get_class_names()
# get identity names
class_names = np.unique(pose_data.identity)
identity_mapping = {n: i for i, n in enumerate(class_names)}

# ingest parts and classes
@@ -186,6 +184,10 @@
identity_position = pose_data[pose_data["identity"] == identity]
if identity_position.empty:
continue

# get anchor part - always the first one of all the body parts
anchor_part = np.unique(identity_position.part)[0]

for part in set(identity_position.part.values):
part_position = identity_position[identity_position.part == part]
part_entries.append(
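The SLEAP ingestion now derives identity names and the anchor part from the pose data itself rather than from the model config; a toy sketch of that logic (illustrative data, not part of the diff):

import numpy as np
import pandas as pd

# Toy pose rows as produced by the SLEAP reader: one row per (identity, part).
pose_data = pd.DataFrame(
    {
        "identity": ["mouse_b", "mouse_a", "mouse_b", "mouse_a"],
        "part": ["nose", "nose", "centroid", "centroid"],
    }
)

# Identity names come from the data and map to integer indices.
class_names = np.unique(pose_data.identity)           # sorted unique names
identity_mapping = {n: i for i, n in enumerate(class_names)}
print(identity_mapping)                               # {'mouse_a': 0, 'mouse_b': 1}

# Per-identity anchor part; np.unique sorts, so this is the first part name
# in sorted order.
identity_position = pose_data[pose_data["identity"] == "mouse_a"]
anchor_part = np.unique(identity_position.part)[0]
print(anchor_part)                                    # 'centroid'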
36 changes: 19 additions & 17 deletions aeon/io/reader.py
@@ -286,7 +286,6 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
super().__init__(pattern, columns=None)
self._model_root = model_root
self.config_file = None # requires reading the data file to be set

def read(self, file: Path) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
@@ -295,9 +294,9 @@ def read(self, file: Path) -> pd.DataFrame:
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
self.config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names()
parts = self.get_bodyparts()
config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names(config_file)
parts = self.get_bodyparts(config_file)

# Using bodyparts, assign column names to Harp register values, and read data in default format.
try: # Bonsai.Sleap0.2
@@ -327,7 +326,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data)
data = self.class_int2str(data, config_file)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
@@ -350,45 +349,48 @@ def read(self, file: Path) -> pd.DataFrame:
new_data = pd.concat(part_data_list)
return new_data.sort_index()

def get_class_names(self) -> list[str]:
@staticmethod
def get_class_names(config_file: Path) -> list[str]:
"""Returns a list of classes from a model's config file."""
classes = None
with open(self.config_file) as f:
with open(config_file) as f:
config = json.load(f)
if self.config_file.stem == "confmap_config": # SLEAP
if config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "class_vectors")["classes"]
except KeyError as err:
if not classes:
raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
raise KeyError(f"Cannot find class_vectors in {config_file}.") from err
return classes

def get_bodyparts(self) -> list[str]:
@staticmethod
def get_bodyparts(config_file: Path) -> list[str]:
"""Returns a list of bodyparts from a model's config file."""
parts = []
with open(self.config_file) as f:
with open(config_file) as f:
config = json.load(f)
if self.config_file.stem == "confmap_config": # SLEAP
if config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
parts = [util.find_nested_key(heads, "anchor_part")]
parts += util.find_nested_key(heads, "part_names")
except KeyError as err:
if not parts:
raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
raise KeyError(f"Cannot find bodyparts in {config_file}.") from err
return parts

def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
@staticmethod
def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
if self.config_file.stem == "confmap_config": # SLEAP
with open(self.config_file) as f:
if config_file.stem == "confmap_config": # SLEAP
with open(config_file) as f:
config = json.load(f)
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {self.config_file}.") from err
raise KeyError(f"Cannot find classes in {config_file}.") from err
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
return data
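Because these helpers are now static and take config_file explicitly, they can be used without instantiating the reader. A hedged usage sketch, assuming the reader class shown above is aeon.io.reader.Pose and with a hypothetical model path (not part of the diff):

from pathlib import Path

from aeon.io import reader  # assumption: the class above is reader.Pose

# Hypothetical SLEAP model directory; the file name follows the
# "confmap_config" convention checked in the code above.
config_file = Path("/ceph/aeon/aeon/data/processed/example_model/confmap_config.json")

identities = reader.Pose.get_class_names(config_file)  # subject ids from the model config
parts = reader.Pose.get_bodyparts(config_file)          # anchor part first, then the rest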
