Skip to content

Commit

Permalink
fix(reader): Pose reader code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ttngu207 committed Jul 2, 2024
1 parent f866d2f commit 099aa5f
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,9 @@ def read(self, file: Path) -> pd.DataFrame:
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names(config_file)
parts = self.get_bodyparts(config_file)
self.config_file = config_file
self.config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names()
parts = self.get_bodyparts()

# Using bodyparts, assign column names to Harp register values, and read data in default format.
try: # Bonsai.Sleap0.2
Expand Down Expand Up @@ -328,7 +327,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data, config_file_dir)
data = self.class_int2str(data)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
Expand All @@ -351,46 +350,45 @@ def read(self, file: Path) -> pd.DataFrame:
new_data = pd.concat(part_data_list)
return new_data.sort_index()

def get_class_names(self, file: Path) -> list[str]:
def get_class_names(self) -> list[str]:
"""Returns a list of classes from a model's config file."""
classes = None
with open(file) as f:
with open(self.config_file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
if self.config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "class_vectors")["classes"]
except KeyError as err:
if not classes:
raise KeyError(f"Cannot find class_vectors in {file}.") from err
raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
return classes

def get_bodyparts(self, file: Path) -> list[str]:
def get_bodyparts(self) -> list[str]:
"""Returns a list of bodyparts from a model's config file."""
parts = []
with open(file) as f:
with open(self.config_file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
if self.config_file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
parts = [util.find_nested_key(heads, "anchor_part")]
parts += util.find_nested_key(heads, "part_names")
except KeyError as err:
if not parts:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
return parts

def class_int2str(self, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
config_file = self.get_config_file(config_file_dir)
if config_file.stem == "confmap_config": # SLEAP
with open(config_file) as f:
if self.config_file.stem == "confmap_config": # SLEAP
with open(self.config_file) as f:
config = json.load(f)
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {config_file}.") from err
raise KeyError(f"Cannot find classes in {self.config_file}.") from err
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
return data
Expand Down

0 comments on commit 099aa5f

Please sign in to comment.