Skip to content

Commit

Permalink
Simplify caching
Browse files Browse the repository at this point in the history
  • Loading branch information
semenko committed Jan 19, 2024
1 parent c400937 commit 8f93dbf
Showing 1 changed file with 73 additions and 112 deletions.
185 changes: 73 additions & 112 deletions src/bam2tensor/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,56 +62,54 @@ def __init__(
self.verbose = verbose
self.window_size = window_size

# A dict of chromosomes -> index for quick lookups (e.g. "chr1" -> 0)
self.chromosomes_dict: dict[str, int] = {
ch: idx for idx, ch in enumerate(self.expected_chromosomes)
}

self.total_cpg_sites = 0

# Store the CpG sites in a dict per chromosome
self.cpg_sites_dict: dict[str, list[int]] = {}

self.cached_cpg_sites_json = self.genome_name + ".cpg_all_sites.json.gz"

self.windowed_cpg_sites_cache = self.genome_name + ".cpg_windowed_sites.json.gz"
# This is a dict of lists, where but each list contains a tuple of CpG ranges witin a window
# Key: chromosome, e.g. "chr1"
# Value: a list of tuples, e.g. [(0,35), (190,212), (1055,)]
self.windowed_cpg_sites_dict: dict[str, list[tuple]] = {}

self.windowed_cpg_sites_reverse_cache = (
self.genome_name + ".cpg_windowed_sites_reverse.json.gz"
)
# And a reverse dict of dicts where chrom->window_start->[cpgs]
self.windowed_cpg_sites_dict_reverse: dict[str, dict[int, list]] = {}

# TODO: Store the expected chromosomes in some cached object.
self.cache_file = self.genome_name + ".cache.json.gz"

# Check that the expected chromosomes are not empty
if len(self.expected_chromosomes) == 0:
raise ValueError("Expected chromosomes cannot be empty")

# Try to load a cached methylation embedding if it exists
# Try to load a cached embedding if it exists
cache_available = False
if not skip_cache:
try:
self.load_cpg_site_cache()
cache_available = self.load_embedding_cache()
except FileNotFoundError as e:
if self.verbose:
print("Could not load methylation embedding from cache: " + str(e))

# If we don't have a cached methylation embedding, parse the fasta file
if len(self.cpg_sites_dict) == 0:
if not cache_available:
# Generate CpG sites if we don't have a cached embedding
self.parse_fasta_for_cpg_sites()

if not skip_cache:
self.save_cpg_site_cache()
# Now generate windowed CpG sites for efficient querying of .bam files
self.generate_windowed_cpg_sites()

if not skip_cache:
# Save the key & expensive objects to a cache
self.save_embedding_cache()

## Generate objects for efficient lookups
# A dict of chromosomes -> index for quick lookups (e.g. "chr1" -> 0)
self.chromosomes_dict: dict[str, int] = {
ch: idx for idx, ch in enumerate(self.expected_chromosomes)
}

# How many CpG sites are there?
self.total_cpg_sites = sum([len(v) for v in self.cpg_sites_dict.values()])
if self.verbose:
print(f"\t\tTotal CpG sites: {self.total_cpg_sites:,}")

# TODO: Shove these and expected_chromosomes into an object that we can save and cache
# Create a dictionary of chromosome -> CpG site -> index (embedding) for efficient lookup
print(self.cpg_sites_dict.keys())
self.chr_to_cpg_to_embedding_dict = {
Expand All @@ -129,29 +127,36 @@ def __init__(
[cpgs_per_chr[k] for k in self.expected_chromosomes]
)

###########
# Now generate windowed CpG sites for efficient querying of .bam files

# Try to load cached window data if available
if not skip_cache:
try:
self.load_windowed_cpg_site_cache()
except FileNotFoundError as e:
if self.verbose:
print("Could not load windowed embedding from cache: " + str(e))
if verbose:
print(f"Loaded methylation embedding for: {self.genome_name}")

# If we don't have a cached methylation embedding, parse the fasta file
if len(self.windowed_cpg_sites_dict) == 0:
self.generate_windowed_cpg_sites()
def save_embedding_cache(self):
"""Save a cache of expensive objects as our methylation embedding."""

if not skip_cache:
self.save_windowed_cpg_site_cache()
assert len(self.cpg_sites_dict) > 0, "CpG sites dict is empty!"
assert (
len(self.windowed_cpg_sites_dict) > 0
), "Windowed CpG sites dict is empty!"

cache_data = {
"genome_name": self.genome_name,
"fasta_source": self.fasta_source,
"expected_chromosomes": self.expected_chromosomes,
"window_size": self.window_size,
"cpg_sites_dict": self.cpg_sites_dict,
"windowed_cpg_sites_dict": self.windowed_cpg_sites_dict,
"windowed_cpg_sites_dict_reverse": self.windowed_cpg_sites_dict_reverse,
}

if verbose:
print(f"Loaded methylation embedding for: {self.genome_name}")
if self.verbose:
print(f"\tSaving embedding to cache: {self.cache_file}")
with gzip.open(self.cache_file, "wt", compresslevel=3, encoding="utf-8") as f:
json.dump(cache_data, f)
if self.verbose:
print("\tSaved embedding cache cache.")

def load_cpg_site_cache(self):
"""Load a cache of CpG sites from a previously parsed fasta.
def load_embedding_cache(self) -> bool:
"""Load our cached embedding data from a prior run.
Raises
-------
Expand All @@ -162,16 +167,40 @@ def load_cpg_site_cache(self):
if self.verbose:
print(f"\tLoading all CpG sites for: {self.genome_name}")

if os.path.exists(self.cached_cpg_sites_json):
if os.path.exists(self.cache_file):
if self.verbose:
print(f"\t\tReading CpG sites from cache: {self.cached_cpg_sites_json}")
print(f"\t\tReading CpG sites from cache: {self.cache_file}")

# TODO: Add type hinting via TypedDicts?
# e.g. https://stackoverflow.com/questions/51291722/define-a-jsonable-type-using-mypy-pep-526
with gzip.open(self.cached_cpg_sites_json, "rt") as f:
self.cpg_sites_dict = json.load(f)
with gzip.open(self.cache_file, "rt") as f:
self.cache_data = json.load(f)

# Load the cached data
self.genome_name = self.cache_data["genome_name"]
self.fasta_source = self.cache_data["fasta_source"]
self.expected_chromosomes = self.cache_data["expected_chromosomes"]
self.window_size = self.cache_data["window_size"]
self.cpg_sites_dict = self.cache_data["cpg_sites_dict"]
self.windowed_cpg_sites_dict = self.cache_data["windowed_cpg_sites_dict"]

# This is to convert the keys back to integers, since JSON only supports strings as keys
# Note that we want this on the second level keys '1234', not the first level keys 'chr1'
self.windowed_cpg_sites_dict_reverse = {
chrom: {
int(cpg) if cpg.isdigit() else cpg: window
for cpg, window in v.items()
}
for chrom, v in self.cache_data[
"windowed_cpg_sites_dict_reverse"
].items()
}

else:
raise FileNotFoundError("\tNo cache of CpG sites found.")

return True

def parse_fasta_for_cpg_sites(self):
"""Generate a dict of *all* CpG sites across each chromosome in the reference genome.
Expand Down Expand Up @@ -229,50 +258,6 @@ def parse_fasta_for_cpg_sites(self):
if self.verbose:
print(f"\tFound {len(self.cpg_sites_dict)} chromosomes in reference fasta.")

def save_cpg_site_cache(self):
"""Save a cache of CpG sites from a previously parsed fasta."""

assert self.cpg_sites_dict is not None

if self.verbose:
print(f"\tSaving all cpg sites to cache: {self.cached_cpg_sites_json}")
with gzip.open(
self.cached_cpg_sites_json, "wt", compresslevel=3, encoding="utf-8"
) as f:
json.dump(self.cpg_sites_dict, f)
if self.verbose:
print("\tSaved CpG cache.")

def load_windowed_cpg_site_cache(self):
"""Load a cache of windowed CpG sites.
Raises:
-------
FileNotFoundError
If the cached windowed CpG site file cannot be found.
"""

if os.path.exists(self.windowed_cpg_sites_cache) and os.path.exists(
self.windowed_cpg_sites_reverse_cache
):
if self.verbose:
print("\tLoading windowed CpG sites from caches:")
print(f"\t\t{self.windowed_cpg_sites_cache}")
print(f"\t\t{self.windowed_cpg_sites_reverse_cache}")

with gzip.open(self.windowed_cpg_sites_cache, "rt") as f:
self.windowed_cpg_sites_dict = json.load(f)
with gzip.open(self.windowed_cpg_sites_reverse_cache, "rt") as f:
# This wild object_hook is to convert the keys back to integers, since JSON only supports strings as keys
self.windowed_cpg_sites_dict_reverse = json.load(
f,
object_hook=lambda d: {
int(k) if k.isdigit() else k: v for k, v in d.items()
},
)
else:
raise FileNotFoundError("\tNo cache of windowed CpG sites found.")

def generate_windowed_cpg_sites(self):
"""Generate a dict of CpG sites for each chromosome in the reference genome.
Expand Down Expand Up @@ -337,30 +322,6 @@ def generate_windowed_cpg_sites(self):
f"Loaded {len(self.windowed_cpg_sites_dict)} chromosomes from window cache."
)

def save_windowed_cpg_site_cache(self):
"""Save a cache of windowed CpG sites."""

# Save these to .json caches
if self.verbose:
print("\tSaving windowed CpG sites to caches:")
print(f"\t\t{self.windowed_cpg_sites_cache}")
print(f"\t\t{self.windowed_cpg_sites_reverse_cache}")

with gzip.open(
self.windowed_cpg_sites_cache, "wt", compresslevel=3, encoding="utf-8"
) as f:
json.dump(self.windowed_cpg_sites_dict, f)
with gzip.open(
self.windowed_cpg_sites_reverse_cache,
"wt",
compresslevel=3,
encoding="utf-8",
) as f:
json.dump(self.windowed_cpg_sites_dict_reverse, f)

if self.verbose:
print("\tSaved windowed Cpg cache.")

def embedding_to_genomic_position(
self, embedding: Union[int, np.int64]
) -> tuple[str, int]:
Expand Down

0 comments on commit 8f93dbf

Please sign in to comment.