Skip to content

Commit

Permalink
tiny debug with the update
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Apr 6, 2024
1 parent 63abf08 commit dd6f1f7
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions scdataloader/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,13 @@ def __post_init__(self):
self.class_topred[clss] = self.mapped_dataset.get_merged_categories(
clss
)
update = {}
c = 0
for k, v in self.mapped_dataset.encoders[clss].items():
if k == self.mapped_dataset.unknown_label:
update.update({k: v})
c += 1
self.class_topred[clss] -= set([k])
else:
update.update({k: v - c})
self.mapped_dataset.encoders[clss] = update
if (
self.mapped_dataset.unknown_label
in self.mapped_dataset.encoders[clss].keys()
):
self.class_topred[clss] -= set(
[self.mapped_dataset.unknown_label]
)

if self.genedf is None:
self.genedf = load_genes(self.organisms)
Expand Down Expand Up @@ -238,6 +235,7 @@ def define_hierarchies(self, labels: list[str]):
if label in self.clss_to_pred:
# if we have added new labels, we need to update the encoder with them too.
mlength = len(self.mapped_dataset.encoders[label])

mlength -= (
1
if self.mapped_dataset.unknown_label
Expand All @@ -253,7 +251,6 @@ def define_hierarchies(self, labels: list[str]):

self.class_topred[label] = lclass
c = 0
d = 0
update = {}
mlength = len(lclass)
# import pdb
Expand All @@ -271,10 +268,9 @@ def define_hierarchies(self, labels: list[str]):
c += 1
elif k == self.mapped_dataset.unknown_label:
update.update({k: v})
d += 1
self.class_topred[label] -= set([k])
else:
update.update({k: (v - c) - d})
update.update({k: v - c})
self.mapped_dataset.encoders[label] = update


Expand Down

0 comments on commit dd6f1f7

Please sign in to comment.