Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Apr 29, 2024
2 parents 9886a6e + ff99b4f commit 98d7e06
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
12 changes: 10 additions & 2 deletions scdataloader/collator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from .utils import load_genes
from .utils import load_genes, downsample_profile
from torch import Tensor, long

# class SimpleCollator:
Expand All @@ -21,6 +21,8 @@ def __init__(
organism_name="organism_ontology_term_id",
class_names=[],
genelist=[],
downsample=None, # don't use it for training!
save_output=False,
):
"""
This class is responsible for collating data for the scPRINT model. It handles the
Expand Down Expand Up @@ -68,9 +70,10 @@ def __init__(
self.organism_name = organism_name
self.tp_name = tp_name
self.class_names = class_names

self.save_output = save_output
self.start_idx = {}
self.accepted_genes = {}
self.downsample = downsample
self.genedf = load_genes(organisms)
self.to_subset = {}
for organism in set(self.genedf.organism):
Expand Down Expand Up @@ -206,6 +209,11 @@ def __call__(self, batch):
}
if len(dataset) > 0:
ret.update({"dataset": Tensor(dataset).to(long)})
if self.downsample is not None:
ret["x"] = downsample_profile(ret["x"], self.downsample)
if self.save_output:
with open("collator_output.txt", "a") as f:
np.savetxt(f, ret["x"].numpy())
return ret


Expand Down
40 changes: 37 additions & 3 deletions scdataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,48 @@
from scipy.stats import median_abs_deviation
from functools import lru_cache
from collections import Counter
from torch import Tensor
import torch

from typing import Union, List, Optional

from anndata import AnnData


def downsample_profile(mat: Tensor, dropout: float):
"""
This function downsamples the expression profile of a given single cell RNA matrix.
The noise is applied based on the renoise parameter,
the total counts of the matrix, and the number of genes. The function first calculates the noise
threshold (scaler) based on the renoise parameter. It then generates an initial matrix count by
applying a Poisson distribution to a random tensor scaled by the total counts and the number of genes.
The function then models the sampling zeros by applying a Poisson distribution to a random tensor
scaled by the noise threshold, the total counts, and the number of genes. The function also models
the technical zeros by generating a random tensor and comparing it to the noise threshold. The final
matrix count is calculated by subtracting the sampling zeros from the initial matrix count and
multiplying by the technical zeros. The function ensures that the final matrix count is not less
than zero by taking the maximum of the final matrix count and a tensor of zeros. The function
returns the final matrix count.
Args:
mat (torch.Tensor): The input matrix.
dropout (float): The renoise parameter.
Returns:
torch.Tensor: The matrix count after applying noise.
"""
batch = mat.shape[0]
ngenes = mat.shape[1]
dropout = dropout * 1.1
# we model the sampling zeros (dropping 30% of the reads)
res = torch.poisson((mat * (dropout / 2))).int()
# we model the technical zeros (dropping 50% of the genes)
notdrop = (torch.rand((batch, ngenes), device=mat.device) >= (dropout / 2)).int()
mat = (mat - res) * notdrop
return torch.maximum(mat, torch.zeros((1, 1), device=mat.device, dtype=torch.int))


def createFoldersFor(filepath: str):
"""
will recursively create folders if needed until having all the folders required to save the file in this filepath
Expand Down Expand Up @@ -404,9 +440,7 @@ def populate_my_ontology(
ln.save(records, parents=bool(tissues))
bt.Tissue(name="unknown", ontology_id="unknown").save()
# DevelopmentalStage
names = (
bt.DevelopmentalStage.public().df().index if not dev_stages else dev_stages
)
names = bt.DevelopmentalStage.public().df().index if not dev_stages else dev_stages
records = bt.DevelopmentalStage.from_values(names, field="ontology_id")
ln.save(records, parents=bool(dev_stages))
bt.DevelopmentalStage(name="unknown", ontology_id="unknown").save()
Expand Down

0 comments on commit 98d7e06

Please sign in to comment.