diff --git a/pyproject.toml b/pyproject.toml
index 3dd530e..e6a1557 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,10 @@ dependencies = [
     "scipy>=1.10.1",
     "pyarrow>=13.0.0",
     "pyreadr>=0.4.9",
-    "dynamicTreeCut>=0.1.1"
+    "dynamicTreeCut>=0.1.1",
+    "plotly>=5.18.0",
+    "numba==0.58.1",
+    "biopython==1.81"
 ]
 description = "Analytical framework for BS-seq data comparison and visualization"
 readme = "README.md"
diff --git a/src/bismarkplot/BismarkPlot.py b/src/bismarkplot/BismarkPlot.py
index 3676ab5..eca3b32 100644
--- a/src/bismarkplot/BismarkPlot.py
+++ b/src/bismarkplot/BismarkPlot.py
@@ -17,13 +17,14 @@
 from pandas import DataFrame as pdDataFrame
 from pyreadr import write_rds
 
-from src.bismarkplot.base import BismarkBase, BismarkFilesBase
-from src.bismarkplot.clusters import Clustering
-from src.bismarkplot.utils import remove_extension, approx_batch_num, prepare_labels, interval
+from .base import BismarkBase, BismarkFilesBase
+from .clusters import Clustering
+from .utils import remove_extension, approx_batch_num, prepare_labels, interval
 
 import plotly.graph_objects as go
 import plotly.express as px
-
+from pathlib import Path
+import pyarrow.parquet as pq
 from collections import OrderedDict
@@ -70,38 +71,50 @@ def from_file(
                    gene_windows=gene_windows,
                    downstream_windows=downstream_windows)
 
-    @staticmethod
-    def __read_bismark_batches(
+    @classmethod
+    def from_parquet(
+            cls,
             file: str,
             genome: pl.DataFrame,
-            upstream_windows: int = 500,
+            upstream_windows: int = 0,
             gene_windows: int = 2000,
-            downstream_windows: int = 500,
-            batch_size: int = 10 ** 7,
-            cpu: int = cpu_count(),
+            downstream_windows: int = 0,
             sumfunc: str = "mean"
-    ) -> pl.DataFrame:
-        cpu = cpu if cpu is not None else cpu_count()
+    ):
+        """
+        Constructor from a preprocessed parquet cytosine report.
 
-        # enable string cache for categorical comparison
-        pl.enable_string_cache(True)
+        :param file: Path to preprocessed parquet cytosine report
+        :param genome: polars.DataFrame with gene ranges
+        :param upstream_windows: Number of windows to split upstream flanking regions into
+        :param downstream_windows: Number of windows to split downstream flanking regions into
+        :param gene_windows: Number of windows to split gene regions into
+        :param sumfunc: Summary function for aggregating cytosine densities within a fragment
+        """
+        if upstream_windows < 1:
+            upstream_windows = 0
+        if downstream_windows < 1:
+            downstream_windows = 0
+        if gene_windows < 1:
+            gene_windows = 0
+
+        bismark_df = cls.__read_parquet_batches(file, genome,
+                                                upstream_windows, gene_windows, downstream_windows, sumfunc)
+        return cls(bismark_df,
+                   upstream_windows=upstream_windows,
+                   gene_windows=gene_windows,
+                   downstream_windows=downstream_windows)
+
+    @staticmethod
+    def __process_batch(df: pl.DataFrame, genome: pl.DataFrame, df_columns, up_win, gene_win, down_win, sumfunc):
         # *** POLARS EXPRESSIONS ***
         # cast genome columns to type to join
         GENE_COLUMNS = [
             pl.col('strand').cast(pl.Categorical),
             pl.col('chr').cast(pl.Categorical)
         ]
-        # cast report columns to optimized type
-        DF_COLUMNS = [
-            pl.col('position').cast(pl.Int32),
-            pl.col('chr').cast(pl.Categorical),
-            pl.col('strand').cast(pl.Categorical),
-            pl.col('context').cast(pl.Categorical),
-            # density for CURRENT cytosine
-            ((pl.col('count_m')) / (pl.col('count_m') + pl.col('count_um'))).alias('density').cast(pl.Float32)
-        ]
-
         # upstream region position check
         UP_REGION = pl.col('position') < pl.col('start')
         # body region position check
@@ -109,23 +122,16 @@ def __read_bismark_batches(
         # downstream region position check
         DOWN_REGION = (pl.col('position') > pl.col('end'))
 
-        UP_FRAGMENT = ((
-            (pl.col('position') - pl.col('upstream')) / (pl.col('start') - pl.col('upstream'))
-        ) * upstream_windows).floor()
+        UP_FRAGMENT = (((pl.col('position') - pl.col('upstream')) / (pl.col('start') - pl.col('upstream'))
+                        ) * up_win).floor()
 
         # fragment even for position == end needs to be rounded by floor
         # so 1e-10 is added (position is always < end)
-        BODY_FRAGMENT = ((
-            (pl.col('position') - pl.col('start')) / (pl.col('end') - pl.col('start') + 1e-10)
-        ) * gene_windows).floor() + upstream_windows
+        BODY_FRAGMENT = (((pl.col('position') - pl.col('start')) / (pl.col('end') - pl.col('start') + 1e-10)
+                          ) * gene_win).floor() + up_win
 
-        DOWN_FRAGMENT = ((
-            (pl.col('position') - pl.col('end')) / (pl.col('downstream') - pl.col('end') + 1e-10)
-        ) * downstream_windows).floor() + upstream_windows + gene_windows
-
-        # batch approximation
-        read_approx = approx_batch_num(file, batch_size)
-        read_batches = 0
+        DOWN_FRAGMENT = (((pl.col('position') - pl.col('end')) / (pl.col('downstream') - pl.col('end') + 1e-10)
+                          ) * down_win).floor() + up_win + gene_win
 
         # Firstly BismarkPlot was written so there were only one sum statistic - mean.
# Sum and count of densities was calculated for further weighted mean analysis in respect to fragment size @@ -146,6 +152,64 @@ def __read_bismark_batches( else: AGG_EXPR = [pl.sum('density').alias('sum'), pl.count('density').alias('count')] + return ( + df.lazy() + # assign types + # calculate density for each cytosine + .with_columns(df_columns) + # drop redundant columns, because individual cytosine density has already been calculated + # individual counts do not matter because every cytosine is equal + + # .drop(['count_m', 'count_um']) + + # sort by position for joining + .sort(['chr', 'strand', 'position']) + # join with nearest + .join_asof( + genome.lazy().with_columns(GENE_COLUMNS), + left_on='position', right_on='upstream', by=['chr', 'strand'] + ) + # limit by end of region + .filter(pl.col('position') <= pl.col('downstream')) + # calculate fragment ids + .with_columns([ + pl.when(UP_REGION).then(UP_FRAGMENT) + .when(BODY_REGION).then(BODY_FRAGMENT) + .when(DOWN_REGION).then(DOWN_FRAGMENT) + .cast(pl.Int32).alias('fragment'), + pl.concat_str( + pl.col("chr"), + (pl.concat_str(pl.col("start"), pl.col("end"), separator="-")), + separator=":").alias("gene").cast(pl.Categorical), + pl.col('id').cast(pl.Categorical) + ]) + # gather fragment stats + .groupby(by=['chr', 'strand', 'gene', 'context', 'id', 'fragment']) + .agg(AGG_EXPR) + .drop_nulls(subset=['sum']) + ).collect() + + @classmethod + def __read_bismark_batches( + cls, + file: str, + genome: pl.DataFrame, + up_win: int = 500, + gene_win: int = 2000, + down_win: int = 500, + batch_size: int = 10 ** 7, + cpu: int = cpu_count(), + sumfunc: str = "mean" + ) -> pl.DataFrame: + cpu = cpu if cpu is not None else cpu_count() + + # enable string cache for categorical comparison + pl.enable_string_cache(True) + + # batch approximation + read_approx = approx_batch_num(file, batch_size) + read_batches = 0 + # *** READING START *** # output dataframe total = None @@ -161,51 +225,23 @@ def __read_bismark_batches( ) batches = bismark.next_batches(cpu) - def process_batch(df: pl.DataFrame): - return ( - df.lazy() - # filter empty rows - .filter((pl.col('count_m') + pl.col('count_um') != 0)) - # assign types - # calculate density for each cytosine - .with_columns(DF_COLUMNS) - # drop redundant columns, because individual cytosine density has already been calculated - # individual counts do not matter because every cytosine is equal - .drop(['count_m', 'count_um']) - # sort by position for joining - .sort(['chr', 'strand', 'position']) - # join with nearest - .join_asof( - genome.lazy().with_columns(GENE_COLUMNS), - left_on='position', right_on='upstream', by=['chr', 'strand'] - ) - # limit by end of region - .filter(pl.col('position') <= pl.col('downstream')) - # calculate fragment ids - .with_columns([ - pl.when(UP_REGION).then(UP_FRAGMENT) - .when(BODY_REGION).then(BODY_FRAGMENT) - .when(DOWN_REGION).then(DOWN_FRAGMENT) - .cast(pl.Int32).alias('fragment'), - pl.concat_str( - pl.col("chr"), - (pl.concat_str(pl.col("start"), pl.col("end"), separator="-")), - separator=":").alias("gene").cast(pl.Categorical), - pl.col('id').cast(pl.Categorical) - ]) - # gather fragment stats - .groupby(by=['chr', 'strand', 'gene', 'context', 'id', 'fragment']) - .agg(AGG_EXPR) - .drop_nulls(subset=['sum']) - ).collect() + df_columns = [ + pl.col('position').cast(pl.Int32), + pl.col('chr').cast(pl.Categorical), + pl.col('strand').cast(pl.Categorical), + pl.col('context').cast(pl.Categorical), + # density for CURRENT cytosine + ((pl.col('count_m')) / 
(pl.col('count_m') + pl.col('count_um'))).alias('density').cast(pl.Float32) + ] print(f"Reading from {file}") while batches: for df in batches: - df = process_batch(df) + df = df.filter((pl.col('count_m') + pl.col('count_um') != 0)) + df = cls.__process_batch(df, genome, df_columns, up_win, gene_win, down_win, sumfunc) if total is None and len(df) == 0: raise Exception( - "Error reading Bismark file. Check format or genome. No joins on first batch.") + "Error reading cytosine file. Check format or genome. No joins on first batch.") elif total is None: total = df else: @@ -219,6 +255,61 @@ def process_batch(df: pl.DataFrame): print("DONE") return total + @classmethod + def __read_parquet_batches( + cls, + file: str, + genome: pl.DataFrame, + up_win: int = 500, + gene_win: int = 2000, + down_win: int = 500, + sumfunc: str = "mean" + ) -> pl.DataFrame: + file = Path(file) + + # enable string cache for categorical comparison + pl.enable_string_cache(True) + + # *** READING START *** + # output dataframe + total = None + # initialize batched reader + pq_file = pq.ParquetFile(file.absolute()) + + # batch approximation + num_row_groups = pq_file.metadata.num_row_groups + + df_columns = [ + pl.col('position').cast(pl.Int32), + pl.col('chr').cast(pl.Categorical), + pl.col('strand').cast(pl.Categorical), + pl.col('context').cast(pl.Categorical), + # density for CURRENT cytosine + (pl.col('count_m') / pl.col('count_total')).alias('density').cast(pl.Float32) + ] + + print(f"Reading from {file}") + + for i in range(num_row_groups): + df = pl.from_arrow(pq_file.read_row_group(i)) + df = df.filter(pl.col("count_total") != 0) + + df = cls.__process_batch(df, genome, df_columns, up_win, gene_win, down_win, sumfunc) + if total is None and len(df) == 0: + raise Exception( + "Error reading cytosine file. Check format or genome. No joins on first batch.") + elif total is None: + total = df + else: + total = total.extend(df) + + print( + f"\tRead {i}/{num_row_groups} batch | Total size - {round(total.estimated_size('mb'), 1)}Mb RAM", + end="\r") + + print("DONE") + return total + def filter(self, context: str = None, strand: str = None, chr: str = None): """ :param context: Methylation context (CG, CHG, CHH) to filter (only one). 
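Reviewer note (not part of the patch): the new __read_parquet_batches above streams the preprocessed report one row group at a time via pyarrow instead of polars' batched CSV reader. A minimal sketch of that pattern, with a placeholder file path and without the genome join that __process_batch performs:

import pyarrow.parquet as pq
import polars as pl

pq_file = pq.ParquetFile("report.parquet")  # placeholder path to a preprocessed cytosine report
for i in range(pq_file.metadata.num_row_groups):
    batch = pl.from_arrow(pq_file.read_row_group(i))  # one row group per batch
    batch = batch.filter(pl.col("count_total") != 0)  # skip uncovered cytosines
    # ... then join_asof against the genome and aggregate per fragment, as __process_batch does ...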
@@ -378,7 +469,7 @@ def __calculate_plot_data(df: pl.DataFrame, stat): def __strand_reverse(self, df: pl.DataFrame): if self.strand == '-': - max_fragment = self.plot_data["fragment"].max() + max_fragment = df["fragment"].max() return df.with_columns((max_fragment - pl.col("fragment")).alias("fragment")) else: return df @@ -425,13 +516,16 @@ def __add_flank_lines(self, axes: Axes, major_labels: list, minor_labels: list, if self.upstream_windows < 1: labels["up_mid"], labels["body_start"] = [""] * 2 - x_ticks = self.tick_positions - x_labels = [labels[key] for key in x_ticks.keys()] + ticks = self.tick_positions + + names = list(ticks.keys()) + x_ticks = [ticks[key] for key in names] + x_labels = [labels[key] for key in names] axes.set_xticks(x_ticks, labels=x_labels) if show_border: - for tick in [x_ticks["body_start"], x_ticks["body_end"]]: + for tick in [ticks["body_start"], ticks["body_end"]]: axes.axvline(x=tick, linestyle='--', color='k', alpha=.3) return axes @@ -449,8 +543,10 @@ def __add_flank_lines_plotly(self, figure: go.Figure, major_labels: list, minor_ labels["up_mid"], labels["body_start"] = [""] * 2 ticks = self.tick_positions - x_ticks = list(ticks.keys()) - x_labels = [labels[key] for key in x_ticks] + + names = list(ticks.keys()) + x_ticks = [ticks[key] for key in names] + x_labels = [labels[key] for key in names] figure.update_layout( xaxis=dict( @@ -684,13 +780,16 @@ def __add_flank_lines(self, axes: Axes, major_labels: list, minor_labels: list, if self.upstream_windows < 1: labels["up_mid"], labels["body_start"] = [""] * 2 - x_ticks = self.tick_positions - x_labels = [labels[key] for key in x_ticks.keys()] + ticks = self.tick_positions + + names = list(ticks.keys()) + x_ticks = [ticks[key] for key in names] + x_labels = [labels[key] for key in names] axes.set_xticks(x_ticks, labels=x_labels) if show_border: - for tick in [x_ticks["body_start"], x_ticks["body_end"]]: + for tick in [ticks["body_start"], ticks["body_end"]]: axes.axvline(x=tick, linestyle='--', color='k', alpha=.3) return axes @@ -707,8 +806,11 @@ def __add_flank_lines_plotly(self, figure: go.Figure, major_labels: list, minor_ if self.upstream_windows < 1: labels["up_mid"], labels["body_start"] = [""] * 2 - x_ticks = self.tick_positions - x_labels = [labels[key] for key in x_ticks.keys()] + ticks = self.tick_positions + + names = list(ticks.keys()) + x_ticks = [ticks[key] for key in names] + x_labels = [labels[key] for key in names] figure.update_layout( xaxis=dict( @@ -718,7 +820,7 @@ def __add_flank_lines_plotly(self, figure: go.Figure, major_labels: list, minor_ ) if show_border: - for tick in [x_ticks["body_start"], x_ticks["body_end"]]: + for tick in [ticks["body_start"], ticks["body_end"]]: figure.add_vline(x=tick, line_dash="dash", line_color="rgba(0,0,0,0.2)") return figure @@ -1095,17 +1197,20 @@ def __add_flank_lines_plotly(self, figure: go.Figure, major_labels: list, minor_ if self.samples[0].upstream_windows < 1: labels["up_mid"], labels["body_start"] = [""] * 2 - x_ticks = self.samples[0].tick_positions - x_labels = [labels[key] for key in x_ticks.keys()] + ticks = self.samples[0].tick_positions + + names = list(ticks.keys()) + x_ticks = [ticks[key] for key in names] + x_labels = [labels[key] for key in names] figure.for_each_xaxis(lambda x: x.update( tickmode='array', - tickvals=list(x_ticks.values()), + tickvals=x_ticks, ticktext=x_labels) ) if show_border: - for tick in [x_ticks["body_start"], x_ticks["body_end"]]: + for tick in [ticks["body_start"], ticks["body_end"]]: 
figure.add_vline(x=tick, line_dash="dash", line_color="rgba(0,0,0,0.2)") return figure @@ -1173,14 +1278,15 @@ def draw_plotly( color="Methylation density" ) + facet_col = 0 figure = px.imshow( samples_matrix, labels=labels, title=title, aspect="auto", color_continuous_scale=color_scale, - facet_col=0, - facet_col_wrap=facet_cols + facet_col=facet_col, + facet_col_wrap=facet_cols if len(self.samples) > facet_cols else len(self.samples) ) # set facet annotations diff --git a/src/bismarkplot/SeqReader.py b/src/bismarkplot/SeqReader.py new file mode 100644 index 0000000..2b4463e --- /dev/null +++ b/src/bismarkplot/SeqReader.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +import gc +import gzip +import io +import multiprocessing +import os +import shutil +import tempfile +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +from Bio import SeqIO as seqio +from numba import njit + +import polars as pl + + +@njit +def convert_trinuc(trinuc, reverse=False): + """ + Get trinucleotide context from raw trinucleotide + :param trinuc: trinucleotide sequence + :param reverse: is trinucleotide from reversed sequence + :return: trinucleotide context + """ + if reverse: + if trinuc[1] == "C": + return "CG" + elif trinuc[0] == "C": + return "CHG" + else: + return "CHH" + else: + if trinuc[1] == "G": + return "CG" + elif trinuc[2] == "G": + return "CHG" + else: + return "CHH" + + +@njit +def get_trinuc(record_seq: str, reverse=False): + """ + Parse sequence and extract trinucleotide contexts and positions + :param record_seq: sequence + :param reverse: does sequence need to be reversed + :return: tuple(positions, contexts) + """ + positions = [] + trinucs = [] + + record_seq = record_seq.upper() + + nuc = "G" if reverse else "C" + up_shift = 1 if reverse else 3 + down_shift = -2 if reverse else 0 + + for position in range(2 if reverse else 0, len(record_seq) if reverse else len(record_seq) - 2): + if record_seq[position] == nuc: + positions.append(position + 1) + trinuc = record_seq[position + down_shift:position + up_shift] + trinucs.append(convert_trinuc(trinuc, reverse)) + + return positions, trinucs + + +def init_tempfile(temp_dir, name, delete, suffix=".bedGraph.parquet") -> Path: + """ + Init temporary cytosine file + :param temp_dir: directory where file will be created + :param name: filename + :param delete: does file need to be deleted after script completion + :return: temporary file + """ + # temp cytosine file + temp_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete) + + # change name if not None + if name is not None: + new_path = Path(temp_dir) / (Path(name).stem + suffix) + + os.rename(temp_file.name, new_path) + temp_file.name = new_path + + return Path(temp_file.name) + + +class Sequence: + def __init__(self, cytosine_file: str | Path): + """ + Class for extracting cytosine contexts and positions + :param path: path to fasta sequence + :param temp_dir: directory, where temporary file will be created + :param name: filename of temporary file + :param delete: does temporary file need to be deleted after script completion + """ + self.cytosine_file = Path(cytosine_file) + + @classmethod + def from_fasta(cls, path: str | Path, temp_dir: str = "./", name: str = None, delete: bool = True): + """ + :param path: path to fasta sequence + :param temp_dir: directory, where temporary file will be created + :param name: filename of temporary file + :param delete: does temporary file need to be deleted after script completion + """ + cytosine_file = 
init_tempfile(temp_dir, name, delete, suffix=".fasta.parquet") + sequence = cls(cytosine_file) + + # read sequence into cytosine file + cls.__read_fasta_wrapper(sequence, fasta_path=path) + + return sequence + + @classmethod + def from_preprocessed(cls, path: str | Path): + path = Path(path) + if not path.exists(): + raise FileNotFoundError("Parquet file not found") + try: + pq.read_metadata(path) + + return cls(path) + + except Exception as e: + raise Exception("Failed reading parquet with exception:\n", e) + + @property + def cytosine_file_schema(self): + return pa.schema([ + ("position", pa.int32()), + ("context", pa.dictionary(pa.int8(), pa.utf8())), + ("chr", pa.dictionary(pa.int8(), pa.utf8())), + ("strand", pa.bool_())]) + + def __infer_schema_table(self, handle: io.TextIOBase): + """ + Initialize dummy table with dictionary columns encoded + :param handle: fasta input stream + :return: dummy table + """ + handle.seek(0) + schema = self.cytosine_file_schema + + # get all chromosome names to categorise them + chrom_ids = [record.id for record in seqio.parse(handle, format="fasta")] + # longer list for length be allways greater than chroms or contexts + chrom_ids += [str(chrom_ids[0])] * 3 + + # init other rows of table + contexts = ["CG", "CHG", "CHH"] + contexts += [str(contexts[0])] * (len(chrom_ids) - len(contexts)) + positions = [-1] * len(chrom_ids) + strand = [True] * len(chrom_ids) + + schema_table = pa.Table.from_arrays( + arrays=[positions, contexts, chrom_ids, strand], + schema=schema + ) + + return schema_table + + def __read_fasta(self, handle): + # init arrow parquet writer + arrow_writer = pq.ParquetWriter(self.cytosine_file, self.cytosine_file_schema) + # prepare dummy table with all dictionary columns already mapped + print("Scanning file to get chromosome ids.") + schema_table = self.__infer_schema_table(handle) + + print("Extracting cytosine contexts.") + print("Writing into", self.cytosine_file) + # return to start byte + handle.seek(0) + for record in seqio.parse(handle, "fasta"): + # POSITIVE + # parse sequence and get all trinucleotide positions and contexts + positions, trinuc = get_trinuc(str(record.seq)) + + # convert into arrow table + arrow_table = pa.Table.from_arrays( + arrays=[positions, trinuc, [record.id for _ in positions], [True for _ in positions]], + schema=self.cytosine_file_schema + ) + # unify dictionary keys with dummy table + # and deselect dummy rows + arrow_table = pa.concat_tables([schema_table, arrow_table]).unify_dictionaries()[len(schema_table):] + # write to file + arrow_writer.write(arrow_table) + + print(f"Read chromosome: {record.id}\t+", end="\r") + + # NEGATIVE + positions, trinuc = get_trinuc(str(record.seq), reverse=True) + + arrow_table = pa.Table.from_arrays( + arrays=[positions, trinuc, [record.id] * len(positions), [False for _ in positions]], + schema=schema_table.schema + ) + + arrow_table = pa.concat_tables([schema_table, arrow_table]).unify_dictionaries()[len(schema_table):] + arrow_writer.write(arrow_table) + + print(f"Read chromosome: {record.id}\t+-") + + print("Done reading fasta sequence.\n") + arrow_writer.close() + + def __read_fasta_wrapper(self, fasta_path: str | Path) -> tempfile.TemporaryFile: + fasta_path = Path(fasta_path) + + print("Reading sequence from:", fasta_path) + if fasta_path.suffix == ".gz": + with gzip.open(fasta_path.absolute(), 'rt') as handle: + return self.__read_fasta(handle) + else: + with open(fasta_path.absolute()) as handle: + return self.__read_fasta(handle) + + def get_metadata(self): + 
return pq.read_metadata(self.cytosine_file)
+
+
+class Mapper:
+    def __init__(self, path):
+        self.report_file = path
+
+    @staticmethod
+    def __map_with_sequence(df_lazy, sequence_df) -> pl.DataFrame:
+        file_types = [
+            pl.col("chr").cast(pl.Categorical),
+            pl.col("position").cast(pl.Int32)
+        ]
+
+        # arrow table aligned to genome
+        chrom_aligned = (
+            df_lazy
+            .with_columns(file_types)
+            .set_sorted("position")
+            .join(sequence_df.lazy(), on=["chr", "position"])
+            .collect()
+        )
+
+        return chrom_aligned
+
+    @staticmethod
+    def __read_filter_sequence(sequence: Sequence, filter: list) -> pa.Table:
+        return pq.read_table(sequence.cytosine_file, filters=filter)
+
+    @staticmethod
+    def __bedGraph_reader(path, batch_size, cpu, skip_rows):
+        return pl.read_csv_batched(
+            path,
+            separator='\t', has_header=False,
+            new_columns=['chr', 'position', 'count_m'],
+            columns=[0, 2, 3],
+            batch_size=batch_size,
+            n_threads=cpu,
+            skip_rows=skip_rows,
+            dtypes=[pl.Utf8, pl.Int64, pl.Float32]
+        )
+
+    @staticmethod
+    def __coverage_reader(path, batch_size, cpu, skip_rows):
+        return pl.read_csv_batched(
+            path,
+            separator='\t', has_header=False,
+            new_columns=['chr', 'position', 'count_m', 'count_um'],
+            columns=[0, 2, 4, 5],
+            batch_size=batch_size,
+            n_threads=cpu,
+            skip_rows=skip_rows,
+            dtypes=[pl.Utf8, pl.Int64, pl.Int32, pl.Int32]
+        )
+
+    @classmethod
+    def __map(cls, where, sequence, batched_reader, cpu, mutations: list[pl.Expr] = None):
+        pl.enable_string_cache()
+        genome_metadata = sequence.get_metadata()
+        genome_rows_read = 0
+
+        pq_writer = None
+
+        batches = batched_reader.next_batches(cpu)
+
+        while batches:
+            for batch in batches:
+                # get batch stats
+                batch_stats = batch.group_by("chr").agg([
+                    pl.col("position").max().alias("max"),
+                    pl.col("position").min().alias("min")
+                ])
+
+                for chrom in batch_stats["chr"]:
+                    chrom_min, chrom_max = [batch_stats.filter(pl.col("chr") == chrom)[stat][0] for stat in
+                                            ["min", "max"]]
+
+                    filters = [
+                        ("chr", "=", chrom),
+                        ("position", ">=", chrom_min),
+                        ("position", "<=", chrom_max)
+                    ]
+
+                    chrom_genome = (
+                        pl.from_arrow(cls.__read_filter_sequence(sequence, filters))
+                        .with_columns(
+                            pl.when(pl.col("strand") == True).then(pl.lit("+"))
+                            .otherwise(pl.lit("-"))
+                            .cast(pl.Categorical)
+                            .alias("strand")
+                        )
+                    )
+
+                    # arrow table aligned to genome
+                    filtered_lazy = batch.lazy().filter(pl.col("chr") == chrom)
+                    filtered_aligned = cls.__map_with_sequence(filtered_lazy, chrom_genome)
+
+                    if mutations is not None:
+                        filtered_aligned = filtered_aligned.with_columns(mutations)
+
+                    missing_cols = set(["chr", "position", "strand", "context", "count_m", "count_total"]) - set(filtered_aligned.columns)
+
+                    # add any missing report columns as nulls so every batch matches the output schema
+                    for column in missing_cols:
+                        filtered_aligned = filtered_aligned.with_columns(pl.lit(None).alias(column))
+
+                    filtered_aligned = filtered_aligned.select(["chr", "position", "strand", "context", "count_m", "count_total"])
+
+                    filtered_aligned = filtered_aligned.to_arrow()
+                    if pq_writer is None:
+                        pq_writer = pa.parquet.ParquetWriter(
+                            where,
+                            schema=filtered_aligned.schema
+                        )
+
+                    pq_writer.write(filtered_aligned)
+                    genome_rows_read += len(chrom_genome)
+
+                    print("Mapped {rows_read}/{rows_total} ({percent}%) cytosines".format(
+                        rows_read=genome_rows_read,
+                        rows_total=genome_metadata.num_rows,
+                        percent=round(genome_rows_read / genome_metadata.num_rows * 100, 2)
+                    ), end="\r")
+
+            gc.collect()
+            batches = batched_reader.next_batches(cpu)
+
+        pq_writer.close()
+
+    @staticmethod
+    def __check_compressed(path: str | Path, temp_dir=None):
+        path = Path(path)
+
+        if path.suffix
== ".gz": + temp_file = tempfile.NamedTemporaryFile(dir=temp_dir) + print(f"Temporarily unpack {path} to {temp_file.name}") + + with gzip.open(path, mode="rb") as file: + shutil.copyfileobj(file, temp_file) + + return temp_file + + else: + return path + + @classmethod + def bedGraph( + cls, + path, + sequence: Sequence, + temp_dir: str = "./", + name: str = None, + delete: bool = True, + batch_size=10 ** 7, + cpu=multiprocessing.cpu_count(), + skip_rows: int = 1 + ): + path = Path(path) + if not path.exists(): + raise FileNotFoundError() + + file = cls.__check_compressed(path, temp_dir) + path = file.name + + report_file = init_tempfile(temp_dir, name, delete, suffix=".bedGraph.parquet") + mapper = Mapper(report_file) + + path = Path(path) + print(f"Started reading bedGraph file from {path}") + + bedGraph_reader = cls.__bedGraph_reader(path, batch_size, cpu, skip_rows) + + mutations = [ + pl.col("count_m") / 100, + pl.lit(1).alias("count_total") + ] + + cls.__map(mapper.report_file, sequence, bedGraph_reader, cpu, mutations) + + print(f"\nDone reading bedGraph sequence\nTable saved to {mapper.report_file}") + + return mapper + + @classmethod + def coverage( + cls, + path, + sequence: Sequence, + temp_dir: str = "./", + name: str = None, + delete: bool = True, + batch_size=10 ** 7, + cpu=multiprocessing.cpu_count(), + skip_rows: int = 1 + ): + path = Path(path) + if not path.exists(): + raise FileNotFoundError() + + file = cls.__check_compressed(path, temp_dir) + path = file.name + + report_file = init_tempfile(temp_dir, name, delete, suffix=".cov.parquet") + mapper = Mapper(report_file) + + path = Path(path) + print(f"Started reading coverage file from {path}") + + coverage_reader = cls.__coverage_reader(path, batch_size, cpu, skip_rows) + + mutations = [ + (pl.col("count_m") + pl.col("count_um")).alias("count_total") + ] + + cls.__map(mapper.report_file, sequence, coverage_reader, cpu, mutations) + + print(f"\nDone reading coverage file\nTable saved to {mapper.report_file}") + + return mapper diff --git a/src/bismarkplot/base.py b/src/bismarkplot/base.py index 8a4acb2..d8ea458 100644 --- a/src/bismarkplot/base.py +++ b/src/bismarkplot/base.py @@ -3,7 +3,7 @@ import polars as pl from pyreadr import write_rds -from src.bismarkplot.utils import remove_extension +from .utils import remove_extension class BismarkBase: diff --git a/src/bismarkplot/console_metagene.py b/src/bismarkplot/console_metagene.py index c1b9123..17ad4be 100644 --- a/src/bismarkplot/console_metagene.py +++ b/src/bismarkplot/console_metagene.py @@ -22,14 +22,14 @@ parser.add_argument('-m', '--mlength', help='minimal length in bp of gene', type=int, default=4000) parser.add_argument('-w', '--gwindows', help='number of windows for genes', type=int, default=100) -parser.add_argument('--line', help='line-plot enabled', action='store_true', default=False) -parser.add_argument('--heatmap', help='heat-map enabled', action='store_true', default=False) -parser.add_argument('--box', help='box-plot enabled', action='store_true', default=False) -parser.add_argument('--violin', help='violin-plot enabled', action='store_true', default=False) +parser.add_argument('--line', help='line-plot enabled', action='store_true', default=True) +parser.add_argument('--heatmap', help='heat-map enabled', action='store_true', default=True) +parser.add_argument('--box', help='box-plot enabled', action='store_true', default=True) +parser.add_argument('--violin', help='violin-plot enabled', action='store_true', default=True) parser.add_argument('-S', 
'--smooth', help='windows for smoothing', type=float, default=10) parser.add_argument('-L', '--labels', help='labels for plots', nargs='+') -parser.add_argument('-C', '--confidence', help='probability for confidence bands for line-plot. 0 if disabled', type=float, default=0) +parser.add_argument('-C', '--confidence', help='probability for confidence bands for line-plot. 0 if disabled', type=float, default=.95) parser.add_argument('-H', help='vertical resolution for heat-map', type=int, default=100, dest="vresolution") parser.add_argument('-V', help='vertical resolution for heat-map', type=int, default=100, dest="hresolution") parser.add_argument("--dpi", help="dpi of output plot", type=int, default=200) @@ -50,13 +50,13 @@ def main(): file=args.genome ) if args.region == "tss": - genome = genome.near_TSS(min_length = args.mlength, flank_length= args.flength) + genome = genome.near_TSS(min_length=args.mlength, flank_length=args.flength) elif args.region == "tes": - genome = genome.near_TES(min_length = args.mlength, flank_length= args.flength) + genome = genome.near_TES(min_length=args.mlength, flank_length=args.flength) elif args.region == "exon": - genome = genome.exon(min_length = args.mlength) + genome = genome.exon(min_length=args.mlength) else: - genome = genome.gene_body(min_length = args.mlength, flank_length= args.flength) + genome = genome.gene_body(min_length=args.mlength, flank_length=args.flength) bismark = MetageneFiles.from_list( filenames=args.filename, diff --git a/src/bismarkplot/genome.py b/src/bismarkplot/genome.py index d72328b..87edb44 100644 --- a/src/bismarkplot/genome.py +++ b/src/bismarkplot/genome.py @@ -1,5 +1,7 @@ -import polars as pl +from __future__ import annotations +import polars as pl +from pathlib import Path class Genome: def __init__(self, genome: pl.LazyFrame): @@ -20,7 +22,7 @@ def __init__(self, genome: pl.LazyFrame): @classmethod def from_custom(cls, - file: str, + file: str | Path, chr_col: int = 0, start_col: int = 1, end_col: int = 2, diff --git a/src/bismarkplot/levels.py b/src/bismarkplot/levels.py index 0e2cfd3..3d0abb6 100644 --- a/src/bismarkplot/levels.py +++ b/src/bismarkplot/levels.py @@ -9,7 +9,7 @@ import plotly.graph_objects as go -from src.bismarkplot.utils import approx_batch_num, interval +from .utils import approx_batch_num, interval class ChrLevels: diff --git a/src/bismarkplot/utils.py b/src/bismarkplot/utils.py index 484954e..9f3aa6b 100644 --- a/src/bismarkplot/utils.py +++ b/src/bismarkplot/utils.py @@ -6,6 +6,13 @@ from scipy import stats +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def remove_extension(path): re.sub("\.[^./]+$", "", path)
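Reviewer note, appended after the diff: a minimal end-to-end sketch of the pipeline this patch introduces (FASTA -> cytosine parquet -> mapped coverage parquet -> metagene). Only Sequence.from_fasta, Mapper.coverage, Genome.from_custom / gene_body and from_parquet appear in the patch itself; the import paths, the Metagene class name, window counts and all file names below are assumptions made for illustration.

from bismarkplot.SeqReader import Sequence, Mapper
from bismarkplot.BismarkPlot import Metagene   # assumed owner of the new from_parquet constructor
from bismarkplot.genome import Genome          # import paths are assumptions

# 1. Extract cytosine positions and contexts from the reference once (writes a parquet file).
sequence = Sequence.from_fasta("genome.fa.gz", temp_dir="./", name="genome", delete=False)

# 2. Map a Bismark coverage report onto those cytosines, producing the preprocessed report.
mapper = Mapper.coverage("sample.cov.gz", sequence, temp_dir="./", name="sample", delete=False)

# 3. Load gene ranges and build the metagene from the mapped parquet instead of a raw CX report.
genome = Genome.from_custom("genes.bed").gene_body(min_length=4000, flank_length=2000)
metagene = Metagene.from_parquet(mapper.report_file, genome,
                                 upstream_windows=50, gene_windows=100, downstream_windows=50)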