Skip to content

Commit

Permalink
Merge branch 'dev_Bonsai-Sleap0.3_reader' into datajoint_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ttngu207 committed Jul 16, 2024
2 parents caca2e6 + d837515 commit 9299ed4
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
super().__init__(pattern, columns=None)
self._model_root = model_root
self.config_file = None # requires reading the data file to be set

def read(self, file: Path) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
Expand All @@ -295,9 +294,9 @@ def read(self, file: Path) -> pd.DataFrame:
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
self.config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names()
parts = self.get_bodyparts()
config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names(config_file)
parts = self.get_bodyparts(config_file)

# Using bodyparts, assign column names to Harp register values, and read data in default format.
try: # Bonsai.Sleap0.2
Expand Down Expand Up @@ -327,7 +326,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data)
data = self.class_int2str(data, config_file)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
Expand All @@ -350,45 +349,48 @@ def read(self, file: Path) -> pd.DataFrame:
new_data = pd.concat(part_data_list)
return new_data.sort_index()

@staticmethod
def get_class_names(config_file: Path) -> list[str]:
    """Returns a list of classes from a model's config file.

    Args:
        config_file: Path to the model's JSON config file. Only a file whose stem is
            "confmap_config" (SLEAP) is interpreted.

    Returns:
        The "classes" item of whatever `util.find_nested_key` locates for
        "class_vectors" under `config["model"]["heads"]`. For any other config file
        the result is `None` (kept for backward compatibility).

    Raises:
        KeyError: If a SLEAP config does not provide a usable "class_vectors" entry.
    """
    # The file is always parsed, even for non-SLEAP configs, so malformed JSON still fails loudly.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    if config_file.stem != "confmap_config":  # not SLEAP
        return None
    try:
        heads = config["model"]["heads"]
        return util.find_nested_key(heads, "class_vectors")["classes"]
    except (KeyError, TypeError) as err:
        # TypeError covers a null "model"/"heads" entry and a lookup result that cannot be
        # subscripted (presumably `None` when the key is absent -- confirm against `util`).
        raise KeyError(f"Cannot find class_vectors in {config_file}.") from err

@staticmethod
def get_bodyparts(config_file: Path) -> list[str]:
    """Returns a list of bodyparts from a model's config file.

    Args:
        config_file: Path to the model's JSON config file. Only a file whose stem is
            "confmap_config" (SLEAP) is interpreted.

    Returns:
        A list starting with the "anchor_part" lookup result, followed by the entries
        of the "part_names" lookup result, both taken from `config["model"]["heads"]`.
        For any other config file the list is empty.

    Raises:
        KeyError: If a SLEAP config yields no bodyparts at all.
    """
    parts = []
    # The file is always parsed, even for non-SLEAP configs, so malformed JSON still fails loudly.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    if config_file.stem != "confmap_config":  # not SLEAP
        return parts
    try:
        heads = config["model"]["heads"]
        parts = [util.find_nested_key(heads, "anchor_part")]
        parts += util.find_nested_key(heads, "part_names")
    except (KeyError, TypeError) as err:
        # TypeError covers a null "model"/"heads" entry and a non-iterable "part_names"
        # lookup result. Best effort is preserved: if the anchor part was already
        # collected, the partial list is returned instead of raising.
        if not parts:
            raise KeyError(f"Cannot find bodyparts in {config_file}.") from err
    return parts

@staticmethod
def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
    """Converts a class integer in a tracking data dataframe to its associated string (subject id).

    Args:
        data: Dataframe with an "identity" column holding class indices. It is modified
            in place.
        config_file: Path to the model's JSON config file. Only a file whose stem is
            "confmap_config" (SLEAP) is interpreted.

    Returns:
        The same `data` object. For a SLEAP config, every "identity" value equal to
        index `i` is replaced by the `i`-th entry of the "classes" lookup result.
        For any other config file `data` is returned untouched.

    Raises:
        KeyError: If a SLEAP config does not provide a "classes" entry.
    """
    if config_file.stem != "confmap_config":  # not SLEAP: nothing to translate
        return data
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    try:
        heads = config["model"]["heads"]
        classes = util.find_nested_key(heads, "classes")
    except (KeyError, TypeError) as err:
        # TypeError covers a null "model"/"heads" entry in the config.
        raise KeyError(f"Cannot find classes in {config_file}.") from err
    if classes is None:
        # Presumably the lookup helper signals "not found" with None -- confirm against `util`.
        raise KeyError(f"Cannot find classes in {config_file}.")
    for i, subj in enumerate(classes):
        data.loc[data["identity"] == i, "identity"] = subj
    return data
Expand Down

0 comments on commit 9299ed4

Please sign in to comment.