diff --git a/README.md b/README.md index d4c358f..43e5759 100644 --- a/README.md +++ b/README.md @@ -123,30 +123,32 @@ Below we will show the basic BismarkPlot workflow. ### Single sample ```python +import src.bismarkplot.genome import bismarkplot + # Firstly, we need to read the regions annotation (e.g. reference genome .gff) -genome = bismarkplot.Genome.from_gff("path/to/genome.gff") +genome = src.bismarkplot.genome.Genome.from_gff("path/to/genome.gff") # Next we need to filter regions of interest from the genome genes = genome.gene_body(min_length=4000, flank_length=2000) # Now we need to calculate metagene data metagene = bismarkplot.Metagene.from_file( - file = "path/to/CX_report.txt", - genome=genes, # filtered regions - upstream_windows = 500, - gene_windows = 1000, - downstream_windows = 500, - batch_size= 10**7 # number of lines to be read simultaneously + file="path/to/CX_report.txt", + genome=genes, # filtered regions + upstream_windows=500, + gene_windows=1000, + downstream_windows=500, + batch_size=10 ** 7 # number of lines to be read simultaneously ) # Our metagene contains all methylation contexts and both strands, so we need to filter it (as in dplyr) -filtered = metagene.filter(context = "CG", strand = "+") +filtered = metagene.filter(context="CG", strand="+") # We are ready to plot -lp = filtered.line_plot() # line plot data -lp.draw().savefig("path/to/lp.pdf") # matplotlib.Figure +lp = filtered.line_plot() # line plot data +lp.draw().savefig("path/to/lp.pdf") # matplotlib.Figure hm = filtered.heat_map(ncol=200, nrow=200) -hm.draw().savefig("path/to/hm.pdf") # matplotlib.Figure +hm.draw().savefig("path/to/hm.pdf") # matplotlib.Figure ``` Output for _Brachypodium distachyon_: @@ -280,17 +282,19 @@ Output for _Brachypodium distachyon_: ```python # For analyzing samples with different reference genomes, we need to initialize several genomes instances +import src.bismarkplot.genome + genome_filenames = ["arabidopsis.gff", "brachypodium.gff", 
"cucumis.gff", "mus.gff"] reports_filenames = ["arabidopsis.txt", "brachypodium.txt", "cucumis.txt", "mus.txt"] genomes = [ - bismarkplot.Genome.from_gff(file).gene_body(...) for file in genome_filenames + src.bismarkplot.genome.Genome.from_gff(file).gene_body(...) for file in genome_filenames ] # Now we read reports metagenes = [] for report, genome in zip(reports_filenames, genomes): - metagene = bismarkplot.Metagene(report, genome = genome, ...) + metagene = bismarkplot.Metagene(report, genome=genome, ...) metagenes.append(metagene) # Initialize MetageneFiles @@ -315,26 +319,29 @@ Output: Other genomic regions from .gff can be analyzed too with ```.exon``` or ```.near_tss/.near_tes``` option for ```bismarkplot.Genome``` ```python +import src.bismarkplot.genome + exons = [ - bismarkplot.Genome.from_gff(file).exon(min_length=100) for file in genome_filenames + src.bismarkplot.genome.Genome.from_gff(file).exon(min_length=100) for file in genome_filenames ] metagenes = [] for report, exon in zip(reports_filenames, exons): - metagene = bismarkplot.Metagene(report, genome = exon, - upstream_windows = 0, # !!! - downstream_windows = 0, # !!! + metagene = bismarkplot.Metagene(report, genome=exon, + upstream_windows=0, # !!! + downstream_windows=0, # !!! ...) metagenes.append(metagene) # OR tss = [ - bismarkplot.Genome.from_gff(file).near_tss(min_length = 2000, flank_length = 2000) for file in genome_filenames + src.bismarkplot.genome.Genome.from_gff(file).near_tss(min_length=2000, flank_length=2000) for file in + genome_filenames ] metagenes = [] for report, t in zip(reports_filenames, tss): - metagene = bismarkplot.Metagene(report, genome = t, - upstream_windows = 1000,# same number of windows - gene_windows = 1000, # same number of windows - downstream_windows = 0, # !!! + metagene = bismarkplot.Metagene(report, genome=t, + upstream_windows=1000, # same number of windows + gene_windows=1000, # same number of windows + downstream_windows=0, # !!! ...) 
metagenes.append(metagene) ``` @@ -356,23 +363,25 @@ TSS output: BismarkPlot allows user to visualize chromosome methylation levels across full genome ```python +import src.bismarkplot.levels import bismarkplot -chr = bismarkplot.ChrLevels.from_file( + +chr = src.bismarkplot.levels.ChrLevels.from_file( "path/to/CX_report.txt", - window_length=10**5, # window length in bp - batch_size=10**7, - chr_min_length = 10**6, # minimum chr length in bp + window_length=10 ** 5, # window length in bp + batch_size=10 ** 7, + chr_min_length=10 ** 6, # minimum chr length in bp ) fig, axes = plt.subplots() for context in ["CG", "CHG", "CHH"]: - chr.filter(strand="+", context=context).draw( - (fig, axes), # to plot contexts on same axes - smooth=10, # window number for smoothing - label=context # labels for lines - ) + chr.filter(strand="+", context=context).draw( + (fig, axes), # to plot contexts on same axes + smooth=10, # window number for smoothing + label=context # labels for lines + ) -fig.savefig(f"chrom.pdf", dpi = 200) +fig.savefig(f"chrom.pdf", dpi=200) ``` Output for _Arabidopsis thaliana_: diff --git a/src/bismarkplot/BismarkPlot.py b/src/bismarkplot/BismarkPlot.py index 746fb9d..f149ca4 100644 --- a/src/bismarkplot/BismarkPlot.py +++ b/src/bismarkplot/BismarkPlot.py @@ -1,9 +1,5 @@ -import gzip import re -from functools import cache -from collections import Counter from multiprocessing import cpu_count -from os.path import getsize import polars as pl @@ -15,829 +11,16 @@ from matplotlib.figure import Figure from scipy.signal import savgol_filter -from scipy.spatial.distance import pdist -from scipy.cluster.hierarchy import linkage, leaves_list from scipy import stats from pandas import DataFrame as pdDataFrame from pyreadr import write_rds -from dynamicTreeCut import cutreeHybrid +from src.bismarkplot.base import BismarkBase, BismarkFilesBase +from src.bismarkplot.clusters import Clustering +from src.bismarkplot.utils import remove_extension, approx_batch_num, 
hm_flank_lines - -def remove_extension(path): - re.sub("\.[^./]+$", "", path) - - -def approx_batch_num(path, batch_size, check_lines=1000): - size = getsize(path) - - length = 0 - with open(path, "rb") as file: - for _ in range(check_lines): - length += len(file.readline()) - - return round(np.ceil(size / (length / check_lines * batch_size))) - - -def hm_flank_lines(axes: Axes, upstream_windows: int, gene_windows: int, downstream_windows: int): - """ - Add flank lines to the given axis (for line plot) - """ - x_ticks = [] - x_labels = [] - if upstream_windows > 0: - x_ticks.append(upstream_windows - .5) - x_labels.append('TSS') - if downstream_windows > 0: - x_ticks.append(gene_windows + downstream_windows - .5) - x_labels.append('TES') - - if x_ticks and x_labels: - axes.set_xticks(x_ticks) - axes.set_xticklabels(x_labels) - for tick in x_ticks: - axes.axvline(x=tick, linestyle='--', color='k', alpha=.3) - - -class Genome: - def __init__(self, genome: pl.LazyFrame): - """ - Class for storing and manipulating genome DataFrame. - - Genome Dataframe columns: - - +------+--------+-------+-------+----------+------------+ - | chr | strand | start | end | upstream | downstream | - +======+========+=======+=======+==========+============+ - | Utf8 | Utf8 | Int32 | Int32 | Int32 | Int32 | - +------+--------+-------+-------+----------+------------+ - - :param genome: :class:`pl.LazyFrame` with genome data. - """ - self.genome = genome - - @classmethod - def from_gff(cls, file: str): - """ - Constructor with parameters for default gff file. - - :param file: path to genome.gff. 
- """ - comment_char = '#' - has_header = False - - id_regex = "^ID=([^;]+)" - - genes = ( - pl.scan_csv( - file, - comment_char=comment_char, - has_header=has_header, - separator='\t', - new_columns=['chr', 'source', 'type', 'start', - 'end', 'score', 'strand', 'frame', 'attribute'], - dtypes={'start': pl.Int32, 'end': pl.Int32, 'chr': pl.Utf8} - ) - .with_columns( - pl.col("attribute").str.extract(id_regex).alias("id") - ) - .select(['chr', 'type', 'start', 'end', 'strand', "id"]) - ) - - print(f"Genome read from {file}") - return cls(genes) - - def gene_body(self, min_length: int = 4000, flank_length: int = 2000) -> pl.DataFrame: - """ - Filter type == gene from gff. - - :param min_length: minimal length of genes. - :param flank_length: length of the flanking region. - :return: :class:`pl.LazyFrame` with genes and their flanking regions. - """ - genes = self.__filter_genes( - self.genome, 'gene', min_length, flank_length) - genes = self.__trim_genes(genes, flank_length).collect() - return self.__check_empty(genes) - - def exon(self, min_length: int = 100) -> pl.DataFrame: - """ - Filter type == exon from gff. - - :param min_length: minimal length of exons. - :return: :class:`pl.LazyFrame` with exons. - """ - flank_length = 0 - genes = self.__filter_genes( - self.genome, 'exon', min_length, flank_length) - genes = self.__trim_genes(genes, flank_length).collect() - return self.__check_empty(genes) - - def cds(self, min_length: int = 100) -> pl.DataFrame: - """ - Filter type == CDS from gff. - - :param min_length: minimal length of CDS. - :return: :class:`pl.LazyFrame` with CDS. - """ - flank_length = 0 - genes = self.__filter_genes( - self.genome, 'CDS', min_length, flank_length) - genes = self.__trim_genes(genes, flank_length).collect() - return self.__check_empty(genes) - - def near_TSS(self, min_length: int = 4000, flank_length: int = 2000): - """ - Get region near TSS - upstream and same length from TSS. - - :param min_length: minimal length of genes. 
- :param flank_length: length of the flanking region. - :return: :class:`pl.LazyFrame` with genes and their flanking regions. - """ - - # decided not to use this - ''' - upstream_length = ( - # when before length is enough - # we set upstream length to specified - pl.when(pl.col('upstream') >= flank_length).then(flank_length) - # when genes are intersecting (current start < previous end) - # we don't take this as upstream region - .when(pl.col('upstream') < 0).then(0) - # when length between genes is not enough for full specified length - # we divide it into half - .otherwise((pl.col('upstream') - (pl.col('upstream') % 2)) // 2) - ) - ''' - upstream_length = flank_length - - gene_type = "gene" - genes = self.__filter_genes( - self.genome, gene_type, min_length, flank_length) - genes = ( - genes - .groupby(['chr', 'strand'], maintain_order=True).agg([ - pl.col('start'), - # upstream shift - (pl.col('start').shift(-1) - pl.col('end')).shift(1) - .fill_null(flank_length) - .alias('upstream'), - pl.col('id') - ]) - .explode(['start', 'upstream', 'id']) - .with_columns([ - (pl.col('start') - upstream_length).alias('upstream'), - (pl.col("start") + flank_length).alias("end") - ]) - .with_columns(pl.col("end").alias("downstream")) - ).collect() - - return self.__check_empty(genes) - - def near_TES(self, min_length: int = 4000, flank_length: int = 2000): - """ - Get region near TES - downstream and same length from TES. - - :param min_length: minimal length of genes. - :param flank_length: length of the flanking region. - :return: :class:`pl.LazyFrame` with genes and their flanking regions. 
- """ - - # decided not to use this - ''' - downstream_length = ( - # when before length is enough - # we set upstream length to specified - pl.when(pl.col('downstream') >= flank_length).then(flank_length) - # when genes are intersecting (current start < previous end) - # we don't take this as upstream region - .when(pl.col('downstream') < 0).then(0) - # when length between genes is not enough for full specified length - # we divide it into half - .otherwise((pl.col('downstream') - pl.col('downstream') % 2) // 2) - ) - ''' - downstream_length = flank_length - - gene_type = "gene" - genes = self.__filter_genes( - self.genome, gene_type, min_length, flank_length) - genes = ( - genes - .groupby(['chr', 'strand'], maintain_order=True).agg([ - pl.col('end'), - # downstream shift - (pl.col('start').shift(-1) - pl.col('end')) - .fill_null(flank_length) - .alias('downstream'), - pl.col('id') - ]) - .explode(['end', 'downstream', 'id']) - .with_columns([ - (pl.col('end') + downstream_length).alias('downstream'), - (pl.col("end") - flank_length).alias("start") - ]) - .with_columns(pl.col("start").alias("upstream")) - ).collect() - - return self.__check_empty(genes) - - def other(self, gene_type: str, min_length: int = 1000, flank_length: int = 100) -> pl.DataFrame: - """ - Filter by selected type. - - :param gene_type: selected type from gff. Cases need to match. - :param min_length: minimal length of genes. - :param flank_length: length of the flanking region. - :return: :class:`pl.LazyFrame` with genes and their flanking regions. 
- """ - genes = self.__filter_genes( - self.genome, gene_type, min_length, flank_length) - genes = self.__trim_genes(genes, flank_length).collect() - return self.__check_empty(genes) - - @staticmethod - def __filter_genes(genes, gene_type, min_length, flank_length): - genes = genes.filter(pl.col('type') == gene_type).drop('type') - - # filter genes, which start < flank_length - if flank_length > 0: - genes = genes.filter(pl.col('start') > flank_length) - # filter genes which don't pass length threshold - if min_length > 0: - genes = genes.filter((pl.col('end') - pl.col('start')) > min_length) - - return genes - - @staticmethod - def __trim_genes(genes, flank_length) -> pl.LazyFrame: - # upstream shift - # calculates length to previous gene on same chr_strand - length_before = (pl.col('start').shift(-1) - pl.col('end')).shift(1).fill_null(flank_length) - # downstream shift - # calculates length to next gene on same chr_strand - length_after = (pl.col('start').shift(-1) - pl.col('end')).fill_null(flank_length) - - # decided not to use this conditions - ''' - upstream_length_conditioned = ( - # when before length is enough - # we set upstream length to specified - pl.when(pl.col('upstream') >= flank_length).then(flank_length) - # when genes are intersecting (current start < previous end) - # we don't take this as upstream region - .when(pl.col('upstream') < 0).then(0) - # when length between genes is not enough for full specified length - # we divide it into half - .otherwise((pl.col('upstream') - (pl.col('upstream') % 2)) // 2) - ) - - downstream_length_conditioned = ( - # when before length is enough - # we set upstream length to specified - pl.when(pl.col('downstream') >= flank_length).then(flank_length) - # when genes are intersecting (current start < previous end) - # we don't take this as upstream region - .when(pl.col('downstream') < 0).then(0) - # when length between genes is not enough for full specified length - # we divide it into half - 
.otherwise((pl.col('downstream') - pl.col('downstream') % 2) // 2) - ) - ''' - - return ( - genes - .groupby(['chr', 'strand'], maintain_order=True).agg([ - pl.col('start'), - pl.col('end'), - length_before.alias('upstream'), - length_after.alias('downstream'), - pl.col('id') - ]) - .explode(['start', 'end', 'upstream', 'downstream', 'id']) - .with_columns([ - # calculates length of region - (pl.col('start') - flank_length).alias('upstream'), - # calculates length of region - (pl.col('end') + flank_length).alias('downstream') - ]) - ) - - @staticmethod - def __check_empty(genes): - if len(genes) > 0: - return genes - else: - raise Exception( - "Genome DataFrame is empty. Are you sure input file is valid?") - - -class BismarkBase: - """ - Base class for :class:`Metagene` and plots. - """ - - def __init__(self, bismark_df: pl.DataFrame, **kwargs): - """ - Base class for Bismark data. - - DataFrame Structure: - - +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ - | chr | strand | context | gene | fragment | sum | count | - +=================+=============+=====================+======================+==================+================+=========================================+ - | Categorical | Categorical | Categorical | Categorical | Int32 | Int32 | Int32 | - +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ - | chromosome name | strand | methylation context | position of cytosine | fragment in gene | sum methylated | count of all cytosines in this position | - +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ - - - :param bismark_df: pl.DataFrame with cytosine methylation status. - :param upstream_windows: Number of upstream windows. Required. 
- :param gene_windows: Number of gene windows. Required. - :param downstream_windows: Number of downstream windows. Required. - :param strand: Strand if filtered. - :param context: Methylation context if filtered. - :param plot_data: Data for plotting. - """ - self.bismark: pl.DataFrame = bismark_df - - self.upstream_windows: int = kwargs.get("upstream_windows") - self.downstream_windows: int = kwargs.get("downstream_windows") - self.gene_windows: int = kwargs.get("gene_windows") - self.plot_data: pl.DataFrame = kwargs.get("plot_data") - self.context: str = kwargs.get("context") - self.strand: str = kwargs.get("strand") - - @property - def metadata(self) -> dict: - """ - :return: Bismark metadata in dict - """ - return { - "upstream_windows": self.upstream_windows, - "downstream_windows": self.downstream_windows, - "gene_windows": self.gene_windows, - "plot_data": self.plot_data, - "context": self.context, - "strand": self.strand - } - - def save_rds(self, filename, compress: bool = False): - """ - Save Bismark DataFrame in Rds. - - :param filename: Path for file. - :param compress: Whether to compress to gzip or not. - """ - write_rds(filename, self.bismark.to_pandas(), - compress="gzip" if compress else None) - - def save_tsv(self, filename, compress=False): - """ - Save Bismark DataFrame in TSV. - - :param filename: Path for file. - :param compress: Whether to compress to gzip or not. 
- """ - if compress: - with gzip.open(filename + ".gz", "wb") as file: - # noinspection PyTypeChecker - self.bismark.write_csv(file, separator="\t") - else: - self.bismark.write_csv(filename, separator="\t") - - @property - def total_windows(self): - return self.upstream_windows + self.downstream_windows + self.gene_windows - - def __len__(self): - return len(self.bismark) - - -class Clustering(BismarkBase): - """ - Class for clustering genes within sample - """ - - def __init__(self, bismark_df: pl.DataFrame, count_threshold=5, dist_method="euclidean", clust_method="average", **kwargs): - """ - :param bismark_df: :class:polars.DataFrame with genes data - :param count_threshold: Minimum counts per fragment - :param dist_method: Method for evaluating distance - :param clust_method: Method for hierarchical clustering - """ - super().__init__(bismark_df, **kwargs) - - if self.bismark["fragment"].max() > 50: - print(f"WARNING: too many windows ({self.bismark['fragment'].max() + 1}), clusterisation may take very long time") - - grouped = ( - self.bismark.lazy() - .with_columns((pl.col("sum") / pl.col("count")).alias("density")) - .group_by(["chr", "strand", "gene", "context"]) - .agg([pl.col("density"), - pl.col("fragment"), - pl.sum("count").alias("gene_count"), - pl.count("fragment").alias("count")]) - ).collect() - - print(f"Starting with:\t{len(grouped)}") - - by_count = grouped.filter(pl.col("gene_count") > (count_threshold * pl.col("count"))) - - print(f"Left after count theshold filtration:\t{len(by_count)}") - - by_count = by_count.filter(pl.col("count") == self.total_windows) - - print(f"Left after empty windows filtration:\t{len(by_count)}") - - if len(by_count) == 0: - print("All genes have empty windows, exiting") - raise ValueError("All genes have empty windows") - - by_count = by_count.explode(["density", "fragment"]).drop(["gene_count", "count"]).fill_nan(0) - - unpivot = by_count.pivot( - index=["chr", "strand", "gene"], - values="density", - 
columns="fragment", - aggregate_function="sum" - ).select( - ["chr", "strand", "gene"] + list(map(str, range(self.total_windows))) - ).with_columns( - pl.col("gene").alias("label") - ) - - self.gene_labels = unpivot.with_columns(pl.col("label").cast(pl.Utf8))["label"].to_numpy() - self.matrix = unpivot[list(map(str, range(self.total_windows)))].to_numpy() - - self.gene_labels = self.gene_labels[~np.isnan(self.matrix).any(axis=1)] - self.matrix = self.matrix[~np.isnan(self.matrix).any(axis=1), :] - - # dist matrix - print("Distances calculation") - self.dist = pdist(self.matrix, metric=dist_method) - # linkage matrix - print("Linkage calculation and minimizing distances") - self.linkage = linkage(self.dist, method=clust_method, optimal_ordering=True) - - self.order = leaves_list(self.linkage) - - def modules(self, **kwargs): - return Modules(self.gene_labels, self.matrix, self.linkage, self.dist, - windows={ - key: self.metadata[key] for key in ["upstream_windows", "gene_windows", "downstream_windows"] - }, - **kwargs) - - # TODO: rewrite save_rds, save_tsv - - def draw( - self, - fig_axes: tuple = None, - title: str = None - ) -> Figure: - """ - Draws heat-map on given :class:`matplotlib.Axes` or makes them itself. - - :param fig_axes: Tuple with (fig, axes) from :meth:`matplotlib.plt.subplots`. - :param title: Title of the plot. 
- :return: - """ - if fig_axes is None: - plt.clf() - fig, axes = plt.subplots() - else: - fig, axes = fig_axes - - vmin = 0 - vmax = np.max(np.array(self.plot_data)) - - image = axes.imshow( - self.matrix[self.order, :], - interpolation="nearest", aspect='auto', - cmap=colormaps['cividis'], - vmin=vmin, vmax=vmax - ) - axes.set_title(title) - axes.set_xlabel('Position') - axes.set_ylabel('') - - hm_flank_lines(axes, self.upstream_windows, self.gene_windows, self.downstream_windows) - - axes.set_yticks([]) - plt.colorbar(image, ax=axes, label='Methylation density') - - return fig - - -class Modules: - """ - Class for module construction and visualization of clustered genes - """ - def __init__(self, labels: list, matrix: np.ndarray, linkage, distance, windows, **kwargs): - if not len(labels) == len(matrix): - raise ValueError("Length of labels and methylation matrix labels don't match") - - self.labels, self.matrix = labels, matrix - self.linkage, self.distance = linkage, distance - - self.__windows = windows - - self.tree = self.__dynamic_tree_cut(**kwargs) - - def recalculate(self, **kwargs): - """ - Recalculate tree with another params - - :param kwargs: any kwargs to cutreeHybrid from dynamicTreeCut - """ - self.tree = self.__dynamic_tree_cut(**kwargs) - - @cache - def __dynamic_tree_cut(self, **kwargs): - return cutreeHybrid(self.linkage, self.distance, **kwargs) - - @property - def __format__table(self) -> pl.DataFrame: - return pl.DataFrame( - {"gene_labels": list(self.labels)} | - {key: list(self.tree[key]) for key in ["labels", "cores", "smallLabels", "onBranch"]} - ) - - def save_rds(self, filename, compress: bool = False): - """ - Save module data in Rds. - - :param filename: Path for file. - :param compress: Whether to compress to gzip or not. - """ - write_rds(filename, self.__format__table.to_pandas(), - compress="gzip" if compress else None) - - def save_tsv(self, filename, compress=False): - """ - Save module data in TSV. 
- - :param filename: Path for file. - :param compress: Whether to compress to gzip or not. - """ - if compress: - with gzip.open(filename + ".gz", "wb") as file: - # noinspection PyTypeChecker - self.__format__table.write_csv(file, separator="\t") - else: - self.__format__table.write_csv(filename, separator="\t") - - def draw( - self, - fig_axes: tuple = None, - title: str = None, - show_labels=True, - show_size=False - ) -> Figure: - """ - Method for visualization of moduled genes. Every row of heat-map represents an average methylation - profile of genes of the module. - - :param fig_axes: tuple(Fig, Axes) to plot - :param title: Title of the plot - :param show_labels: Enable/disable module number labels - :param show_size: Enable/disable module size labels (in brackets) - """ - - me_matrix, me_labels = [], [] - label_stats = Counter(self.tree["labels"]) - - # iterate every label - for label in label_stats.keys(): - # select genes from module - module_genes = self.tree["labels"] == label - # append mean module pattern - me_matrix.append(self.matrix[module_genes, :].mean(axis=0)) - - me_labels.append(f"{label} ({label_stats[label]})" if show_size else str(label)) - - me_matrix, me_labels = np.stack(me_matrix), np.stack(me_labels) - # sort matrix to minimize distances between modules - order = leaves_list(linkage( - y=pdist(me_matrix, metric="euclidean"), - method="average", - optimal_ordering=True - )) - - me_matrix, me_labels = me_matrix[order, :], me_labels[order] - - if fig_axes is None: - plt.clf() - fig, axes = plt.subplots() - else: - fig, axes = fig_axes - - vmin = 0 - vmax = np.max(np.array(me_matrix)) - - image = axes.imshow( - me_matrix, - interpolation="nearest", aspect='auto', - cmap=colormaps['cividis'], - vmin=vmin, vmax=vmax - ) - - if show_labels: - axes.set_yticks(np.arange(.5, len(me_labels), 1)) - axes.set_yticklabels(me_labels) - else: - axes.set_yticks([]) - - axes.set_title(title) - axes.set_xlabel('Position') - axes.set_ylabel('Module') - # 
axes.yaxis.tick_right() - - hm_flank_lines( - axes, - self.__windows["upstream_windows"], - self.__windows["gene_windows"], - self.__windows["downstream_windows"], - ) - - plt.colorbar(image, ax=axes, label='Methylation density', - # orientation="horizontal", location="top" - ) - - return fig - - - - -class ChrLevels: - def __init__(self, df: pl.DataFrame) -> None: - self.bismark = df - self.plot_data = ( - df - .sort(["chr", "window"]) - .with_row_count("fragment") - .group_by(["chr", "fragment"], maintain_order=True) - .agg([pl.sum("sum"), pl.sum("count")]) - .with_columns((pl.col("sum") / pl.col("count")).alias("density")) - ) - - @classmethod - def from_file( - cls, - file: str, - chr_min_length = 10**6, - window_length: int = 10**6, - batch_size: int = 10 ** 6, - cpu: int = cpu_count() - ): - """ - Initialize ChrLevels with CX_report file - - :param file: Path to file - :param chr_min_length: Minimum length of chromosome to be analyzed - :param window_length: Length of windows in bp - :param cpu: How many cores to use. 
Uses every physical core by default - :param batch_size: Number of rows to read by one CPU core - """ - cpu = cpu if cpu is not None else cpu_count() - - bismark = pl.read_csv_batched( - file, - separator='\t', has_header=False, - new_columns=['chr', 'position', 'strand', - 'count_m', 'count_um', 'context'], - columns=[0, 1, 2, 3, 4, 5], - batch_size=batch_size, - n_threads=cpu - ) - read_approx = approx_batch_num(file, batch_size) - read_batches = 0 - - total = None - - batches = bismark.next_batches(cpu) - print(f"Reading from {file}") - while batches: - for df in batches: - df = ( - df.lazy() - .filter((pl.col('count_m') + pl.col('count_um') != 0)) - .group_by(["strand", "chr"]) - .agg([ - pl.col("context"), - (pl.col("position") / window_length).floor().alias("window").cast(pl.Int32), - ((pl.col('count_m')) / (pl.col('count_m') + pl.col('count_um'))).alias('density').cast(pl.Float32), - (pl.max("position") - pl.min("position")).alias("length") - ]) - .filter(pl.col("length") > chr_min_length) - .explode(["context", "window", "density"]) - .group_by(by=['chr', 'strand', 'context', 'window']) - .agg([ - pl.sum('density').alias('sum'), - pl.count('density').alias('count') - ]) - .drop_nulls(subset=['sum']) - ).collect() - if total is None and len(df) == 0: - raise Exception( - "Error reading Bismark file. Check format or genome. 
No joins on first batch.") - elif total is None: - total = df - else: - total = total.extend(df) - - read_batches += 1 - print( - f"\tRead {read_batches}/{read_approx} batch | Total size - {round(total.estimated_size('mb'), 1)}Mb RAM", end="\r") - batches = bismark.next_batches(cpu) - print("DONE") - - return cls(total) - - def save_plot_rds(self, path, compress: bool = False): - """ - Saves plot data in a rds DataFrame with columns: - - +----------+---------+ - | fragment | density | - +==========+=========+ - | Int | Float | - +----------+---------+ - """ - write_rds(path, self.plot_data.to_pandas(), - compress="gzip" if compress else None) - - def filter(self, context: str = None, strand: str = None, chr: str = None): - """ - :param context: Methylation context (CG, CHG, CHH) to filter (only one). - :param strand: Strand to filter (+ or -). - :param chr: Chromosome name to filter. - :return: Filtered :class:`Bismark`. - """ - context_filter = self.bismark["context"] == context if context is not None else True - strand_filter = self.bismark["strand"] == strand if strand is not None else True - chr_filter = self.bismark["chr"] == chr if chr is not None else True - - if context_filter is None and strand_filter is None and chr_filter is None: - return self - else: - return self.__class__(self.bismark.filter(context_filter & strand_filter & chr_filter)) - - def draw( - self, - fig_axes: tuple = None, - smooth: int = 10, - label: str = None, - linewidth: float = 1.0, - linestyle: str = '-', - ) -> Figure: - - if fig_axes is None: - fig, axes = plt.subplots() - else: - fig, axes = fig_axes - - ticks_data = self.plot_data.group_by("chr", maintain_order = True).agg(pl.min("fragment")) - - x_lines = ticks_data["fragment"].to_numpy() - x_lines = np.append(x_lines, self.plot_data["fragment"].max()) - - x_ticks = (x_lines[1:] + x_lines[:-1]) // 2 - - # get middle ticks - - x_labels = ticks_data["chr"].to_list() - - data = self.plot_data["density"].to_numpy() - - polyorder = 
3 - window = smooth if smooth > polyorder else polyorder + 1 - - if smooth: - data = savgol_filter(data, window, 3, mode='nearest') - - x = np.arange(len(data)) - data = data * 100 # convert to percents - axes.plot(x, data, label=label, - linestyle=linestyle, linewidth=linewidth) - - axes.set_xticks(x_ticks) - axes.set_xticklabels(x_labels) - - axes.legend() - axes.set_ylabel('Methylation density, %') - axes.set_xlabel('Position') - - for tick in x_lines: - axes.axvline(x=tick, linestyle='--', color='k', alpha=.1) - - fig.set_size_inches(12, 5) - - return fig +import plotly.graph_objects as go class Metagene(BismarkBase): @@ -853,7 +36,8 @@ def from_file( gene_windows: int = 2000, downstream_windows: int = 0, batch_size: int = 10 ** 6, - cpu: int = cpu_count() + cpu: int = cpu_count(), + sumfunc: str = "mean" ): """ Constructor from Bismark coverage2cytosine output. @@ -875,7 +59,7 @@ def from_file( bismark_df = cls.__read_bismark_batches(file, genome, upstream_windows, gene_windows, downstream_windows, - batch_size, cpu) + batch_size, cpu, sumfunc) return cls(bismark_df, upstream_windows=upstream_windows, @@ -890,21 +74,22 @@ def __read_bismark_batches( gene_windows: int = 2000, downstream_windows: int = 500, batch_size: int = 10 ** 7, - cpu: int = cpu_count() + cpu: int = cpu_count(), + sumfunc: str = "mean" ) -> pl.DataFrame: cpu = cpu if cpu is not None else cpu_count() # enable string cache for categorical comparison pl.enable_string_cache(True) - # POLARS EXPRESSIONS + # *** POLARS EXPRESSIONS *** # cast genome columns to type to join - gene_columns = [ + GENE_COLUMNS = [ pl.col('strand').cast(pl.Categorical), pl.col('chr').cast(pl.Categorical) ] # cast report columns to optimized type - df_columns = [ + DF_COLUMNS = [ pl.col('position').cast(pl.Int32), pl.col('chr').cast(pl.Categorical), pl.col('strand').cast(pl.Categorical), @@ -914,23 +99,23 @@ def __read_bismark_batches( ] # upstream region position check - upstream_region = pl.col('position') < 
pl.col('start') + UP_REGION = pl.col('position') < pl.col('start') # body region position check - body_region = (pl.col('start') <= pl.col('position')) & (pl.col('position') <= pl.col('end')) + BODY_REGION = (pl.col('start') <= pl.col('position')) & (pl.col('position') <= pl.col('end')) # downstream region position check - downstream_region = (pl.col('position') > pl.col('end')) + DOWN_REGION = (pl.col('position') > pl.col('end')) - upstream_fragment = (( + UP_FRAGMENT = (( (pl.col('position') - pl.col('upstream')) / (pl.col('start') - pl.col('upstream')) ) * upstream_windows).floor() # fragment even for position == end needs to be rounded by floor # so 1e-10 is added (position is always < end) - body_fragment = (( + BODY_FRAGMENT = (( (pl.col('position') - pl.col('start')) / (pl.col('end') - pl.col('start') + 1e-10) ) * gene_windows).floor() + upstream_windows - downstream_fragment = (( + DOWN_FRAGMENT = (( (pl.col('position') - pl.col('end')) / (pl.col('downstream') - pl.col('end') + 1e-10) ) * downstream_windows).floor() + upstream_windows + gene_windows @@ -938,6 +123,26 @@ def __read_bismark_batches( read_approx = approx_batch_num(file, batch_size) read_batches = 0 + # Firstly BismarkPlot was written so there were only one sum statistic - mean. + # Sum and count of densities was calculated for further weighted mean analysis in respect to fragment size + # For backwards compatibility, for newly introduces statistics, column names are kept the same. + # Count is set to 1 and "sum" to actual statistics (e.g. 
median, min, e.t.c) + if sumfunc == "median": + AGG_EXPR = [pl.median("density").alias("sum"), pl.lit(1).alias("count")] + elif sumfunc == "min": + AGG_EXPR = [pl.min("density").alias("sum"), pl.lit(1).alias("count")] + elif sumfunc == "max": + AGG_EXPR = [pl.max("density").alias("sum"), pl.lit(1).alias("count")] + elif sumfunc == "geometric": + AGG_EXPR = [pl.col("density").log().mean().exp().alias("sum"), + pl.lit(1).alias("count")] + elif sumfunc == "1pgeometric": + AGG_EXPR = [(pl.col("density").log1p().mean().exp() - 1).alias("sum"), + pl.lit(1).alias("count")] + else: + AGG_EXPR = [pl.sum('density').alias('sum'), pl.count('density').alias('count')] + + # *** READING START *** # output dataframe total = None # initialize batched reader @@ -959,7 +164,7 @@ def process_batch(df: pl.DataFrame): .filter((pl.col('count_m') + pl.col('count_um') != 0)) # assign types # calculate density for each cytosine - .with_columns(df_columns) + .with_columns(DF_COLUMNS) # drop redundant columns, because individual cytosine density has already been calculated # individual counts do not matter because every cytosine is equal .drop(['count_m', 'count_um']) @@ -967,16 +172,16 @@ def process_batch(df: pl.DataFrame): .sort(['chr', 'strand', 'position']) # join with nearest .join_asof( - genome.lazy().with_columns(gene_columns), + genome.lazy().with_columns(GENE_COLUMNS), left_on='position', right_on='upstream', by=['chr', 'strand'] ) # limit by end of region .filter(pl.col('position') <= pl.col('downstream')) # calculate fragment ids .with_columns([ - pl.when(upstream_region).then(upstream_fragment) - .when(body_region).then(body_fragment) - .when(downstream_region).then(downstream_fragment) + pl.when(UP_REGION).then(UP_FRAGMENT) + .when(BODY_REGION).then(BODY_FRAGMENT) + .when(DOWN_REGION).then(DOWN_FRAGMENT) .cast(pl.Int32).alias('fragment'), pl.concat_str( pl.col("chr"), @@ -986,10 +191,7 @@ def process_batch(df: pl.DataFrame): ]) # gather fragment stats .groupby(by=['chr', 
'strand', 'gene', 'context', 'id', 'fragment']) - .agg([ - pl.sum('density').alias('sum'), - pl.count('density').alias('count') - ]) + .agg(AGG_EXPR) .drop_nulls(subset=['sum']) ).collect() @@ -1114,40 +316,68 @@ def clustering(self, count_threshold = 5, dist_method="euclidean", clust_method= return Clustering(self.bismark, count_threshold, dist_method, clust_method, **self.metadata) - def line_plot(self, resolution: int = None): + def line_plot(self, resolution: int = None, stat="wmean"): """ :param resolution: Number of fragments to resize to. Keep None if not needed. :return: :class:`LinePlot`. """ bismark = self.resize(resolution) - return LinePlot(bismark.bismark, **bismark.metadata) + return LinePlot(bismark.bismark, stat=stat, **bismark.metadata) - def heat_map(self, nrow: int = 100, ncol: int = 100): + def heat_map(self, nrow: int = 100, ncol: int = 100, stat="wmean"): """ :param nrow: Number of fragments to resize to. Keep None if not needed. :param ncol: Number of columns in the resulting heat-map. :return: :class:`HeatMap`. """ bismark = self.resize(ncol) - return HeatMap(bismark.bismark, nrow, order=None, **bismark.metadata) + return HeatMap(bismark.bismark, nrow, order=None, stat=stat, **bismark.metadata) class LinePlot(BismarkBase): - def __init__(self, bismark_df: pl.DataFrame, **kwargs): + def __init__(self, bismark_df: pl.DataFrame, stat="wmean", **kwargs): """ Calculates plot data for line-plot. 
""" super().__init__(bismark_df, **kwargs) - self.plot_data = self.bismark.group_by(["context", "fragment"]).agg([ - pl.col("sum"), pl.col("count"), - (pl.sum("sum") / pl.sum("count")).alias("density") - ]).sort("fragment") + self.stat = stat + + plot_data = self.__calculate_plot_data(bismark_df, stat) + plot_data = self.__strand_reverse(plot_data) + self.plot_data = plot_data + + @staticmethod + def __calculate_plot_data(df: pl.DataFrame, stat): + if stat == "log": + stat_expr = (pl.col("sum") / pl.col("count")).log1p().mean().exp() - 1 + elif stat == "wlog": + stat_expr = (((pl.col("sum") / pl.col("count")).log1p() * pl.col("count")).sum() / pl.sum("count")).exp() - 1 + elif stat == "mean": + stat_expr = (pl.col("sum") / pl.col("count")).mean() + elif re.search("^q(\d+)", stat): + quantile = re.search("q(\d+)", stat).group(1) + stat_expr = (pl.col("sum") / pl.col("count")).quantile(int(quantile) / 100) + else: + stat_expr = pl.sum("sum") / pl.sum("count") + + res = ( + df + .group_by(["context", "fragment"]).agg([ + pl.col("sum"), pl.col("count"), + (stat_expr).alias("density") + ]) + .sort("fragment") + ) + return res + + def __strand_reverse(self, df: pl.DataFrame): if self.strand == '-': max_fragment = self.plot_data["fragment"].max() - self.plot_data = self.plot_data.with_columns( - (max_fragment - pl.col("fragment")).alias("fragment")) + return df.with_columns((max_fragment - pl.col("fragment")).alias("fragment")) + else: + return df @staticmethod def __interval(sum_density: list[int], sum_counts: list[int], alpha=.95): @@ -1187,12 +417,46 @@ def save_plot_rds(self, path, compress: bool = False): write_rds(path, df.to_pandas(), compress="gzip" if compress else None) + def __get_x_y(self, df, smooth, confidence): + if 0 < confidence < 1: + df = ( + df + .with_columns( + pl.struct(["sum", "count"]).map_elements( + lambda x: self.__interval(x["sum"], x["count"], confidence) + ).alias("interval") + ) + .unnest("interval") + .select(["fragment", "lower", 
"density", "upper"]) + ) + + data = df["density"] + + polyorder = 3 + window = smooth if smooth > polyorder else polyorder + 1 + + if smooth: + data = savgol_filter(data, window, 3, mode='nearest') + + lower, upper = None, None + data = data * 100 # convert to percents + + if 0 < confidence < 1: + upper = df["upper"].to_numpy() * 100 # convert to percents + lower = df["lower"].to_numpy() * 100 # convert to percents + + upper = savgol_filter(upper, window, 3, mode="nearest") if smooth else upper + lower = savgol_filter(lower, window, 3, mode="nearest") if smooth else lower + + return lower, data, upper + + def draw( self, fig_axes: tuple = None, smooth: int = 50, label: str = "", - confidence = 0, + confidence: int = 0, linewidth: float = 1.0, linestyle: str = '-', ) -> Figure: @@ -1214,43 +478,18 @@ def draw( contexts = self.plot_data["context"].unique().to_list() - for context in self.plot_data["context"].unique().to_list(): + for context in contexts: df = self.plot_data.filter(pl.col("context") == context) - if 0 < confidence < 1: - df = ( - df - .with_columns( - pl.struct(["sum", "count"]).map_elements( - lambda x: self.__interval(x["sum"], x["count"], confidence) - ).alias("interval") - ) - .unnest("interval") - .select(["fragment", "lower", "density", "upper"]) - ) - - data = df["density"] - - polyorder = 3 - window = smooth if smooth > polyorder else polyorder + 1 - - if smooth: - data = savgol_filter(data, window, 3, mode='nearest') + lower, data, upper = self.__get_x_y(df, smooth, confidence) x = np.arange(len(data)) - data = data * 100 # convert to percents axes.plot(x, data, label=f"{context}" if not label else f"{label}_{context}", linestyle=linestyle, linewidth=linewidth) if 0 < confidence < 1: - upper = df["upper"].to_numpy() * 100 # convert to percents - lower = df["lower"].to_numpy() * 100 # convert to percents - - upper = savgol_filter(upper, window, 3, mode="nearest") if smooth else upper - lower = savgol_filter(lower, window, 3, mode="nearest") if 
smooth else lower - axes.fill_between(x, lower, upper, alpha=.2) self.__add_flank_lines(axes) @@ -1262,6 +501,78 @@ def draw( return fig + def draw_plotly( + self, + figure: go.Figure = None, + smooth: int = 50, + label: str = "", + confidence: int = 0 + ): + if figure is None: + figure = go.Figure() + + contexts = self.plot_data["context"].unique().to_list() + + for context in contexts: + df = self.plot_data.filter(pl.col("context") == context) + + lower, data, upper = self.__get_x_y(df, smooth, confidence) + + x = np.arange(len(data)) + + traces = [go.Scatter(x=x, y=data, name=f"{context}" if not label else f"{label}_{context}", mode="lines")] + + if 0 < confidence < 1: + traces += [ + go.Scatter(x=x, y=upper, mode="lines", line_color = 'rgba(0,0,0,0)', showlegend=False, + name=f"{context}_{confidence}CI" if not label else f"{label}_{context}_{confidence}CI"), + go.Scatter(x=x, y=lower, mode="lines", line_color = 'rgba(0,0,0,0)', showlegend=True, + fill="tonexty", fillcolor='rgba(0, 0, 0, 0.2)', + name=f"{context}_{confidence}CI" if not label else f"{label}_{context}_{confidence}CI"), + ] + + figure.add_traces(traces) + + # self.__add_flank_lines(axes) + # + # axes.legend() + # + # axes.set_ylabel('Methylation density, %') + # axes.set_xlabel('Position') + + figure.update_layout( + xaxis_title="Position", + yaxis_title="Methylation density, %" + ) + + self.__add_flank_lines_plotly(figure) + + return figure + + def __add_flank_lines_plotly(self, figure: go.Figure): + """ + Add flank lines to the given axis (for line plot) + """ + x_ticks = [] + x_labels = [] + if self.upstream_windows > 0: + x_ticks.append(self.upstream_windows - 1) + x_labels.append('TSS') + if self.downstream_windows > 0: + x_ticks.append(self.gene_windows + self.upstream_windows) + x_labels.append('TES') + + figure.update_layout( + xaxis=dict( + tickmode='array', + tickvals=x_ticks, + ticktext=x_labels + ) + ) + + for tick in x_ticks: + figure.add_vline(x=tick, line_dash="dash", 
line_color="rgba(0,0,0,0.2)") + def __add_flank_lines(self, axes: plt.Axes): """ Add flank lines to the given axis (for line plot) @@ -1282,20 +593,38 @@ def __add_flank_lines(self, axes: plt.Axes): class HeatMap(BismarkBase): - def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, **kwargs): + def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, stat="wmean", **kwargs): super().__init__(bismark_df, **kwargs) + plot_data = self.__calculcate_plot_data(bismark_df, nrow, order, stat) + plot_data = self.__strand_reverse(plot_data) + + self.plot_data = plot_data + + def __calculcate_plot_data(self, df, nrow, order=None, stat="wmean"): + if stat == "log": + stat_expr = (pl.col("sum") / pl.col("count")).log1p().mean().exp() - 1 + elif stat == "wlog": + stat_expr = (((pl.col("sum") / pl.col("count")).log1p() * pl.col("count")).sum() / pl.sum("count")).exp() - 1 + elif stat == "mean": + stat_expr = (pl.col("sum") / pl.col("count")).mean() + elif re.search("^q(\d+)", stat): + quantile = re.search("q(\d+)", stat).group(1) + stat_expr = (pl.col("sum") / pl.col("count")).quantile(int(quantile) / 100) + else: + stat_expr = pl.sum("sum") / pl.sum("count") + order = ( - self.bismark.lazy() + df.lazy() .groupby(['chr', 'strand', "gene"]) .agg( - (pl.col('sum').sum() / pl.col('count').sum()).alias("order") + (stat_expr).alias("order") ) ).collect()["order"] if order is None else order # sort by rows and add row numbers hm_data = ( - self.bismark.lazy() + df.lazy() .groupby(['chr', 'strand', "gene"]) .agg( pl.col('fragment'), pl.col('sum'), pl.col('count') @@ -1315,7 +644,7 @@ def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, **kwargs): # calc sum count for row|fragment .groupby(['row', 'fragment']) .agg( - (pl.sum('sum') / pl.sum('count')).alias('density') + (stat_expr).alias('density') ) ) @@ -1346,14 +675,18 @@ def __init__(self, bismark_df: pl.DataFrame, nrow, order=None, **kwargs): ).collect() # convert to matrix - self.plot_data = np.array( + 
plot_data = np.array( hm_data.groupby('row', maintain_order=True).agg( pl.col('density'))['density'].to_list(), dtype=np.float32 ) + return plot_data + + def __strand_reverse(self, df: np.ndarray): if self.strand == '-': - self.plot_data = np.fliplr(self.plot_data) + return np.fliplr(df) + return df def draw( self, @@ -1399,62 +732,6 @@ def save_plot_rds(self, path, compress: bool = False): compress="gzip" if compress else None) -class BismarkFilesBase: - def __init__(self, samples, labels: list[str] = None): - self.samples = self.__check_metadata( - samples if isinstance(samples, list) else [samples]) - if samples is None: - raise Exception("Flank or gene windows number does not match!") - self.labels = [str(v) for v in list( - range(len(samples)))] if labels is None else labels - if len(self.labels) != len(self.samples): - raise Exception("Labels length doesn't match samples number") - - def save_rds(self, base_filename, compress: bool = False, merge: bool = False): - if merge: - merged = pl.concat( - [sample.bismark.lazy().with_columns(pl.lit(label)) - for sample, label in zip(self.samples, self.labels)] - ) - write_rds(base_filename, merged.to_pandas(), - compress="gzip" if compress else None) - if not merge: - for sample, label in zip(self.samples, self.labels): - sample.save_rds( - f"{remove_extension(base_filename)}_{label}.rds", compress="gzip" if compress else None) - - def save_tsv(self, base_filename, compress: bool = False, merge: bool = False): - if merge: - merged = pl.concat( - [sample.bismark.lazy().with_columns(pl.lit(label)) - for sample, label in zip(self.samples, self.labels)] - ) - if compress: - with gzip.open(base_filename + ".gz", "wb") as file: - # noinspection PyTypeChecker - merged.write_csv(file, separator="\t") - else: - merged.write_csv(base_filename, separator="\t") - if not merge: - for sample, label in zip(self.samples, self.labels): - sample.save_tsv( - f"{remove_extension(base_filename)}_{label}.rds", compress=compress) - - 
@staticmethod - def __check_metadata(samples: list[BismarkBase]): - upstream_check = set([sample.metadata["upstream_windows"] - for sample in samples]) - downstream_check = set( - [sample.metadata["downstream_windows"] for sample in samples]) - gene_check = set([sample.metadata["gene_windows"] - for sample in samples]) - - if len(upstream_check) == len(gene_check) == len(downstream_check) == 1: - return samples - else: - return None - - class MetageneFiles(BismarkFilesBase): """ Stores and plots multiple Bismark data. @@ -1526,17 +803,17 @@ def merge(self): else: raise Exception("Metadata for merge DataFrames does not match!") - def line_plot(self, resolution: int = None): + def line_plot(self, resolution: int = None, stat: str = "wmean"): """ :class:`LinePlot` for all files. """ - return LinePlotFiles([sample.line_plot(resolution) for sample in self.samples], self.labels) + return LinePlotFiles([sample.line_plot(resolution, stat) for sample in self.samples], self.labels) - def heat_map(self, nrow: int = 100, ncol: int = None): + def heat_map(self, nrow: int = 100, ncol: int = None, stat: str = "wmean"): """ :class:`HeatMap` for all files. """ - return HeatMapFiles([sample.heat_map(nrow, ncol) for sample in self.samples], self.labels) + return HeatMapFiles([sample.heat_map(nrow, ncol, stat) for sample in self.samples], self.labels) def violin_plot(self, fig_axes: tuple = None): """ @@ -1582,6 +859,56 @@ def box_plot(self, fig_axes: tuple = None, showfliers=False): return fig + def __dendrogram(self, groups, stat="mean"): + # get intersecting regions + gene_sets = [set(sample.bismark["gene"].to_list()) for sample in self.samples] + intersecting = list(set.intersection(*gene_sets)) + + if len(intersecting) < 1: + print("No regions with same labels were found. 
Exiting.") + return + + # TODO check options setter for stat (limited set of options) + # Lazy + def region_levels(bismark: pl.DataFrame, stat="mean"): + if stat == "median": + expr = pl.median("density") + elif stat == "min": + expr = pl.min("density") + elif stat == "max": + expr = pl.max("density") + else: + expr = pl.mean("density") + + levels = ( + bismark.lazy() + .with_columns((pl.col("sum") / pl.col("count")).alias("density")) + .group_by(["gene"]) + .agg(expr.alias("stat")) + .sort("gene") + ) + + return levels + + levels = [region_levels(sample.bismark, stat).rename({"stat": str(label)}) + for sample, label in zip(self.samples, self.labels)] + + data = pl.concat(levels, how="align").collect() + + matrix = data.select(pl.exclude("gene")).to_numpy() + genes = data["gene"].to_numpy() + + # get intersected + matrix = matrix[np.isin(genes, intersecting), :] + + constant = .1 + log2matrix = np.log2(matrix + constant) + + groups = np.array(groups) + logFC = np.mean(log2matrix[:, groups == 1], axis=1) - np.mean(log2matrix[:, groups == 2], axis=1) + + return + class LinePlotFiles(BismarkFilesBase): def draw( @@ -1589,7 +916,7 @@ def draw( smooth: int = 50, linewidth: float = 1.0, linestyle: str = '-', - confidence=0 + confidence: int = 0 ): plt.clf() fig, axes = plt.subplots() @@ -1599,6 +926,14 @@ def draw( return fig + def draw_plotly(self, smooth: int = 50, confidence: int = 0): + figure = go.Figure() + for lp, label in zip(self.samples, self.labels): + assert isinstance(lp, LinePlot) + lp.draw_plotly(figure, smooth, label, confidence) + + return figure + def save_plot_rds(self, base_filename, compress: bool = False, merge: bool = False): if merge: merged = pl.concat( @@ -1663,3 +998,4 @@ def save_plot_rds(self, base_filename, compress: bool = False): for sample, label in zip(self.samples, self.labels): sample.save_plot_rds(f"{remove_extension(base_filename)}_{label}.rds", compress="gzip" if compress else None) + diff --git a/src/bismarkplot/__init__.py 
b/src/bismarkplot/__init__.py index 4b519e1..4a44cd3 100644 --- a/src/bismarkplot/__init__.py +++ b/src/bismarkplot/__init__.py @@ -1,3 +1,5 @@ -from .BismarkPlot import Metagene, MetageneFiles, Genome, ChrLevels +from .BismarkPlot import Metagene, MetageneFiles +from .genome import Genome +from .levels import ChrLevels __version__ = 1.3 diff --git a/src/bismarkplot/base.py b/src/bismarkplot/base.py new file mode 100644 index 0000000..1375767 --- /dev/null +++ b/src/bismarkplot/base.py @@ -0,0 +1,145 @@ +import gzip + +import polars as pl +from pyreadr import write_rds + +from src.bismarkplot.utils import remove_extension + + +class BismarkBase: + """ + Base class for :class:`Metagene` and plots. + """ + + def __init__(self, bismark_df: pl.DataFrame, **kwargs): + """ + Base class for Bismark data. + + DataFrame Structure: + + +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ + | chr | strand | context | gene | fragment | sum | count | + +=================+=============+=====================+======================+==================+================+=========================================+ + | Categorical | Categorical | Categorical | Categorical | Int32 | Int32 | Int32 | + +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ + | chromosome name | strand | methylation context | position of cytosine | fragment in gene | sum methylated | count of all cytosines in this position | + +-----------------+-------------+---------------------+----------------------+------------------+----------------+-----------------------------------------+ + + + :param bismark_df: pl.DataFrame with cytosine methylation status. + :param upstream_windows: Number of upstream windows. Required. + :param gene_windows: Number of gene windows. Required. 
+ :param downstream_windows: Number of downstream windows. Required. + :param strand: Strand if filtered. + :param context: Methylation context if filtered. + :param plot_data: Data for plotting. + """ + self.bismark: pl.DataFrame = bismark_df + + self.upstream_windows: int = kwargs.get("upstream_windows") + self.downstream_windows: int = kwargs.get("downstream_windows") + self.gene_windows: int = kwargs.get("gene_windows") + self.plot_data: pl.DataFrame = kwargs.get("plot_data") + self.context: str = kwargs.get("context") + self.strand: str = kwargs.get("strand") + + @property + def metadata(self) -> dict: + """ + :return: Bismark metadata in dict + """ + return { + "upstream_windows": self.upstream_windows, + "downstream_windows": self.downstream_windows, + "gene_windows": self.gene_windows, + "plot_data": self.plot_data, + "context": self.context, + "strand": self.strand + } + + def save_rds(self, filename, compress: bool = False): + """ + Save Bismark DataFrame in Rds. + + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. + """ + write_rds(filename, self.bismark.to_pandas(), + compress="gzip" if compress else None) + + def save_tsv(self, filename, compress=False): + """ + Save Bismark DataFrame in TSV. + + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. 
+ """ + if compress: + with gzip.open(filename + ".gz", "wb") as file: + # noinspection PyTypeChecker + self.bismark.write_csv(file, separator="\t") + else: + self.bismark.write_csv(filename, separator="\t") + + @property + def total_windows(self): + return self.upstream_windows + self.downstream_windows + self.gene_windows + + def __len__(self): + return len(self.bismark) + + +class BismarkFilesBase: + def __init__(self, samples, labels: list[str] = None): + self.samples = self.__check_metadata( + samples if isinstance(samples, list) else [samples]) + if samples is None: + raise Exception("Flank or gene windows number does not match!") + self.labels = [str(v) for v in list( + range(len(samples)))] if labels is None else labels + if len(self.labels) != len(self.samples): + raise Exception("Labels length doesn't match samples number") + + def save_rds(self, base_filename, compress: bool = False, merge: bool = False): + if merge: + merged = pl.concat( + [sample.bismark.lazy().with_columns(pl.lit(label)) + for sample, label in zip(self.samples, self.labels)] + ) + write_rds(base_filename, merged.to_pandas(), + compress="gzip" if compress else None) + if not merge: + for sample, label in zip(self.samples, self.labels): + sample.save_rds( + f"{remove_extension(base_filename)}_{label}.rds", compress="gzip" if compress else None) + + def save_tsv(self, base_filename, compress: bool = False, merge: bool = False): + if merge: + merged = pl.concat( + [sample.bismark.lazy().with_columns(pl.lit(label)) + for sample, label in zip(self.samples, self.labels)] + ) + if compress: + with gzip.open(base_filename + ".gz", "wb") as file: + # noinspection PyTypeChecker + merged.write_csv(file, separator="\t") + else: + merged.write_csv(base_filename, separator="\t") + if not merge: + for sample, label in zip(self.samples, self.labels): + sample.save_tsv( + f"{remove_extension(base_filename)}_{label}.rds", compress=compress) + + @staticmethod + def __check_metadata(samples: 
list[BismarkBase]): + upstream_check = set([sample.metadata["upstream_windows"] + for sample in samples]) + downstream_check = set( + [sample.metadata["downstream_windows"] for sample in samples]) + gene_check = set([sample.metadata["gene_windows"] + for sample in samples]) + + if len(upstream_check) == len(gene_check) == len(downstream_check) == 1: + return samples + else: + return None diff --git a/src/bismarkplot/clusters.py b/src/bismarkplot/clusters.py new file mode 100644 index 0000000..bdfde1c --- /dev/null +++ b/src/bismarkplot/clusters.py @@ -0,0 +1,270 @@ +import gzip +from collections import Counter +from functools import cache + +import numpy as np +import polars as pl +from dynamicTreeCut import cutreeHybrid +from matplotlib import pyplot as plt, colormaps +from matplotlib.figure import Figure +from pyreadr import write_rds +from scipy.cluster.hierarchy import linkage, leaves_list +from scipy.spatial.distance import pdist + +from .base import BismarkBase +from .utils import hm_flank_lines + + +class Clustering(BismarkBase): + """ + Class for clustering genes within sample + """ + + def __init__(self, bismark_df: pl.DataFrame, count_threshold=5, dist_method="euclidean", clust_method="average", **kwargs): + """ + :param bismark_df: :class:polars.DataFrame with genes data + :param count_threshold: Minimum counts per fragment + :param dist_method: Method for evaluating distance + :param clust_method: Method for hierarchical clustering + """ + super().__init__(bismark_df, **kwargs) + + if self.bismark["fragment"].max() > 50: + print(f"WARNING: too many windows ({self.bismark['fragment'].max() + 1}), clusterisation may take very long time") + + grouped = ( + self.bismark.lazy() + .with_columns((pl.col("sum") / pl.col("count")).alias("density")) + .group_by(["chr", "strand", "gene", "context"]) + .agg([pl.col("density"), + pl.col("fragment"), + pl.sum("count").alias("gene_count"), + pl.count("fragment").alias("count")]) + ).collect() + + print(f"Starting 
with:\t{len(grouped)}") + + by_count = grouped.filter(pl.col("gene_count") > (count_threshold * pl.col("count"))) + + print(f"Left after count theshold filtration:\t{len(by_count)}") + + by_count = by_count.filter(pl.col("count") == self.total_windows) + + print(f"Left after empty windows filtration:\t{len(by_count)}") + + if len(by_count) == 0: + print("All genes have empty windows, exiting") + raise ValueError("All genes have empty windows") + + by_count = by_count.explode(["density", "fragment"]).drop(["gene_count", "count"]).fill_nan(0) + + unpivot = by_count.pivot( + index=["chr", "strand", "gene"], + values="density", + columns="fragment", + aggregate_function="sum" + ).select( + ["chr", "strand", "gene"] + list(map(str, range(self.total_windows))) + ).with_columns( + pl.col("gene").alias("label") + ) + + self.gene_labels = unpivot.with_columns(pl.col("label").cast(pl.Utf8))["label"].to_numpy() + self.matrix = unpivot[list(map(str, range(self.total_windows)))].to_numpy() + + self.gene_labels = self.gene_labels[~np.isnan(self.matrix).any(axis=1)] + self.matrix = self.matrix[~np.isnan(self.matrix).any(axis=1), :] + + # dist matrix + print("Distances calculation") + self.dist = pdist(self.matrix, metric=dist_method) + # linkage matrix + print("Linkage calculation and minimizing distances") + self.linkage = linkage(self.dist, method=clust_method, optimal_ordering=True) + + self.order = leaves_list(self.linkage) + + def modules(self, **kwargs): + return Modules(self.gene_labels, self.matrix, self.linkage, self.dist, + windows={ + key: self.metadata[key] for key in ["upstream_windows", "gene_windows", "downstream_windows"] + }, + **kwargs) + + # TODO: rewrite save_rds, save_tsv + + def draw( + self, + fig_axes: tuple = None, + title: str = None + ) -> Figure: + """ + Draws heat-map on given :class:`matplotlib.Axes` or makes them itself. + + :param fig_axes: Tuple with (fig, axes) from :meth:`matplotlib.plt.subplots`. + :param title: Title of the plot. 
+ :return: + """ + if fig_axes is None: + plt.clf() + fig, axes = plt.subplots() + else: + fig, axes = fig_axes + + vmin = 0 + vmax = np.max(np.array(self.plot_data)) + + image = axes.imshow( + self.matrix[self.order, :], + interpolation="nearest", aspect='auto', + cmap=colormaps['cividis'], + vmin=vmin, vmax=vmax + ) + axes.set_title(title) + axes.set_xlabel('Position') + axes.set_ylabel('') + + hm_flank_lines(axes, self.upstream_windows, self.gene_windows, self.downstream_windows) + + axes.set_yticks([]) + plt.colorbar(image, ax=axes, label='Methylation density') + + return fig + + +class Modules: + """ + Class for module construction and visualization of clustered genes + """ + def __init__(self, labels: list, matrix: np.ndarray, linkage, distance, windows, **kwargs): + if not len(labels) == len(matrix): + raise ValueError("Length of labels and methylation matrix labels don't match") + + self.labels, self.matrix = labels, matrix + self.linkage, self.distance = linkage, distance + + self.__windows = windows + + self.tree = self.__dynamic_tree_cut(**kwargs) + + def recalculate(self, **kwargs): + """ + Recalculate tree with another params + + :param kwargs: any kwargs to cutreeHybrid from dynamicTreeCut + """ + self.tree = self.__dynamic_tree_cut(**kwargs) + + @cache + def __dynamic_tree_cut(self, **kwargs): + return cutreeHybrid(self.linkage, self.distance, **kwargs) + + @property + def __format__table(self) -> pl.DataFrame: + return pl.DataFrame( + {"gene_labels": list(self.labels)} | + {key: list(self.tree[key]) for key in ["labels", "cores", "smallLabels", "onBranch"]} + ) + + def save_rds(self, filename, compress: bool = False): + """ + Save module data in Rds. + + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. + """ + write_rds(filename, self.__format__table.to_pandas(), + compress="gzip" if compress else None) + + def save_tsv(self, filename, compress=False): + """ + Save module data in TSV. 
+ + :param filename: Path for file. + :param compress: Whether to compress to gzip or not. + """ + if compress: + with gzip.open(filename + ".gz", "wb") as file: + # noinspection PyTypeChecker + self.__format__table.write_csv(file, separator="\t") + else: + self.__format__table.write_csv(filename, separator="\t") + + def draw( + self, + fig_axes: tuple = None, + title: str = None, + show_labels=True, + show_size=False + ) -> Figure: + """ + Method for visualization of moduled genes. Every row of heat-map represents an average methylation + profile of genes of the module. + + :param fig_axes: tuple(Fig, Axes) to plot + :param title: Title of the plot + :param show_labels: Enable/disable module number labels + :param show_size: Enable/disable module size labels (in brackets) + """ + + me_matrix, me_labels = [], [] + label_stats = Counter(self.tree["labels"]) + + # iterate every label + for label in label_stats.keys(): + # select genes from module + module_genes = self.tree["labels"] == label + # append mean module pattern + me_matrix.append(self.matrix[module_genes, :].mean(axis=0)) + + me_labels.append(f"{label} ({label_stats[label]})" if show_size else str(label)) + + me_matrix, me_labels = np.stack(me_matrix), np.stack(me_labels) + # sort matrix to minimize distances between modules + order = leaves_list(linkage( + y=pdist(me_matrix, metric="euclidean"), + method="average", + optimal_ordering=True + )) + + me_matrix, me_labels = me_matrix[order, :], me_labels[order] + + if fig_axes is None: + plt.clf() + fig, axes = plt.subplots() + else: + fig, axes = fig_axes + + vmin = 0 + vmax = np.max(np.array(me_matrix)) + + image = axes.imshow( + me_matrix, + interpolation="nearest", aspect='auto', + cmap=colormaps['cividis'], + vmin=vmin, vmax=vmax + ) + + if show_labels: + axes.set_yticks(np.arange(.5, len(me_labels), 1)) + axes.set_yticklabels(me_labels) + else: + axes.set_yticks([]) + + axes.set_title(title) + axes.set_xlabel('Position') + axes.set_ylabel('Module') + # 
axes.yaxis.tick_right() + + hm_flank_lines( + axes, + self.__windows["upstream_windows"], + self.__windows["gene_windows"], + self.__windows["downstream_windows"], + ) + + plt.colorbar(image, ax=axes, label='Methylation density', + # orientation="horizontal", location="top" + ) + + return fig diff --git a/src/bismarkplot/console_chrs.py b/src/bismarkplot/console_chrs.py index 9321880..8e4d058 100644 --- a/src/bismarkplot/console_chrs.py +++ b/src/bismarkplot/console_chrs.py @@ -26,7 +26,7 @@ def main(): args = parser.parse_args() try: - from .BismarkPlot import ChrLevels + from src.bismarkplot import ChrLevels import matplotlib.pyplot as plt chr = ChrLevels.from_file( diff --git a/src/bismarkplot/console_metagene.py b/src/bismarkplot/console_metagene.py index f7d7308..c1b9123 100644 --- a/src/bismarkplot/console_metagene.py +++ b/src/bismarkplot/console_metagene.py @@ -44,7 +44,8 @@ def main(): exit() try: - from .BismarkPlot import MetageneFiles, Genome + from .BismarkPlot import MetageneFiles + from src.bismarkplot import Genome genome = Genome.from_gff( file=args.genome ) diff --git a/src/bismarkplot/genome.py b/src/bismarkplot/genome.py new file mode 100644 index 0000000..d72328b --- /dev/null +++ b/src/bismarkplot/genome.py @@ -0,0 +1,309 @@ +import polars as pl + + +class Genome: + def __init__(self, genome: pl.LazyFrame): + """ + Class for storing and manipulating genome DataFrame. + + Genome Dataframe columns: + + +------+--------+-------+-------+----------+------------+ + | chr | strand | start | end | upstream | downstream | + +======+========+=======+=======+==========+============+ + | Utf8 | Utf8 | Int32 | Int32 | Int32 | Int32 | + +------+--------+-------+-------+----------+------------+ + + :param genome: :class:`pl.LazyFrame` with genome data. 
+ """ + self.genome = genome + + @classmethod + def from_custom(cls, + file: str, + chr_col: int = 0, + start_col: int = 1, + end_col: int = 2, + id_col: int = None, + strand_col: int = 5, + type_col: int = None, + comment_char: str = "#", + has_header: bool = False): + + if sum([val is None for val in [chr_col, strand_col, start_col, end_col]]) > 0: + raise Exception("All position columns need to be specified!") + + genes = ( + pl.scan_csv( + file, + comment_char=comment_char, + has_header=has_header, + separator='\t' + ) + ) + cols = genes.columns + select_cols = [ + pl.col(cols[chr_col]).alias("chr").cast(pl.Utf8), + pl.col(cols[type_col]).alias("type") if type_col is not None else pl.lit(None).alias("type"), + pl.col(cols[start_col]).alias("start").cast(pl.Int32), + pl.col(cols[end_col]).alias("end").cast(pl.Int32), + pl.col(cols[strand_col]).alias("strand"), + pl.col(cols[id_col]).alias("id") if id_col is not None else pl.lit("").alias("id"), + ] + + genes = genes.with_columns(select_cols).drop(cols) + + print(f"Genome read from {file}") + return cls(genes) + + @classmethod + def from_gff(cls, file: str): + """ + Constructor with parameters for default gff file. + + :param file: path to genome.gff. + """ + + id_regex = "^ID=([^;]+)" + + genome = cls.from_custom(file, + 0, 3, 4, 8, 6, 2, + "#", False) + + genome.genome = genome.genome.with_columns( + pl.col("id").str.extract(id_regex) + ) + return genome + + def all(self, min_length: int = 4000, flank_length: int = 2000) -> pl.DataFrame: + genes = self.__filter_genes( + self.genome, None, min_length, flank_length) + genes = self.__trim_genes(genes, flank_length).collect() + return self.__check_empty(genes) + + def gene_body(self, min_length: int = 4000, flank_length: int = 2000) -> pl.DataFrame: + """ + Filter type == gene from gff. + + :param min_length: minimal length of genes. + :param flank_length: length of the flanking region. + :return: :class:`pl.LazyFrame` with genes and their flanking regions. 
def exon(self, min_length: int = 100) -> pl.DataFrame:
    """
    Select the regions whose type is ``exon``.

    The flank length is fixed at 0, so no start-position filter is applied
    and the upstream/downstream bounds coincide with the exon itself.

    :param min_length: minimal length of exons.
    :return: collected :class:`pl.DataFrame` with exons.
    :raises Exception: if no exon passes the filters.
    """
    no_flank = 0
    exons = self.__filter_genes(self.genome, 'exon', min_length, no_flank)
    trimmed = self.__trim_genes(exons, no_flank)
    return self.__check_empty(trimmed.collect())
def near_TES(self, min_length: int = 4000, flank_length: int = 2000):
    """
    Get the region around the TES: ``flank_length`` before the gene end
    and the same length after it.

    For every region of type ``gene`` the result holds
    ``start == upstream == end - flank_length`` and
    ``downstream == end + flank_length``.

    :param min_length: minimal length of genes.
    :param flank_length: length of the flanking region.
    :return: collected frame with the columns chr, strand, end, downstream,
        id, start and upstream.
    :raises Exception: if no gene passes the filters.
    """
    # NOTE(review): ``end`` is used as the TES regardless of strand - confirm
    # this is intended for genes on the "-" strand.
    genes_only = self.__filter_genes(self.genome, "gene", min_length, flank_length)

    # Distance to the next gene of the same chr/strand group. A variant that
    # shrank the flank to fit this gap was considered and dropped; the value
    # is replaced by the fixed-length bound below and only keeps the column
    # in its position.
    gap_to_next = (
        (pl.col('start').shift(-1) - pl.col('end'))
        .fill_null(flank_length)
        .alias('downstream')
    )

    per_group = (
        genes_only
        .groupby(['chr', 'strand'], maintain_order=True)
        .agg([pl.col('end'), gap_to_next, pl.col('id')])
    )
    per_gene = per_group.explode(['end', 'downstream', 'id'])

    bounded = per_gene.with_columns([
        (pl.col('end') + flank_length).alias('downstream'),
        (pl.col("end") - flank_length).alias("start")
    ])
    result = bounded.with_columns(pl.col("start").alias("upstream")).collect()

    return self.__check_empty(result)
+ """ + genes = self.__filter_genes( + self.genome, gene_type, min_length, flank_length) + genes = self.__trim_genes(genes, flank_length).collect() + return self.__check_empty(genes) + + @staticmethod + def __filter_genes(genes, gene_type, min_length, flank_length): + if gene_type is not None: + genes = genes.filter(pl.col('type') == gene_type).drop('type') + else: + genes = genes.drop("type") + + # filter genes, which start < flank_length + if flank_length > 0: + genes = genes.filter(pl.col('start') > flank_length) + # filter genes which don't pass length threshold + if min_length > 0: + genes = genes.filter((pl.col('end') - pl.col('start')) > min_length) + + return genes + + @staticmethod + def __trim_genes(genes, flank_length) -> pl.LazyFrame: + # upstream shift + # calculates length to previous gene on same chr_strand + length_before = (pl.col('start').shift(-1) - pl.col('end')).shift(1).fill_null(flank_length) + # downstream shift + # calculates length to next gene on same chr_strand + length_after = (pl.col('start').shift(-1) - pl.col('end')).fill_null(flank_length) + + # decided not to use this conditions + ''' + upstream_length_conditioned = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('upstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we don't take this as upstream region + .when(pl.col('upstream') < 0).then(0) + # when length between genes is not enough for full specified length + # we divide it into half + .otherwise((pl.col('upstream') - (pl.col('upstream') % 2)) // 2) + ) + + downstream_length_conditioned = ( + # when before length is enough + # we set upstream length to specified + pl.when(pl.col('downstream') >= flank_length).then(flank_length) + # when genes are intersecting (current start < previous end) + # we don't take this as upstream region + .when(pl.col('downstream') < 0).then(0) + # when length between genes is not enough for full 
specified length + # we divide it into half + .otherwise((pl.col('downstream') - pl.col('downstream') % 2) // 2) + ) + ''' + + return ( + genes + .groupby(['chr', 'strand'], maintain_order=True).agg([ + pl.col('start'), + pl.col('end'), + length_before.alias('upstream'), + length_after.alias('downstream'), + pl.col('id') + ]) + .explode(['start', 'end', 'upstream', 'downstream', 'id']) + .with_columns([ + # calculates length of region + (pl.col('start') - flank_length).alias('upstream'), + # calculates length of region + (pl.col('end') + flank_length).alias('downstream') + ]) + ) + + @staticmethod + def __check_empty(genes): + if len(genes) > 0: + return genes + else: + raise Exception( + "Genome DataFrame is empty. Are you sure input file is valid?") diff --git a/src/bismarkplot/levels.py b/src/bismarkplot/levels.py new file mode 100644 index 0000000..cba3cd0 --- /dev/null +++ b/src/bismarkplot/levels.py @@ -0,0 +1,177 @@ +from multiprocessing import cpu_count + +import numpy as np +import polars as pl +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from pyreadr import write_rds +from scipy.signal import savgol_filter + +from src.bismarkplot.utils import approx_batch_num + + +class ChrLevels: + def __init__(self, df: pl.DataFrame) -> None: + self.bismark = df + self.plot_data = ( + df + .sort(["chr", "window"]) + .with_row_count("fragment") + .group_by(["chr", "fragment"], maintain_order=True) + .agg([pl.sum("sum"), pl.sum("count")]) + .with_columns((pl.col("sum") / pl.col("count")).alias("density")) + ) + + @classmethod + def from_file( + cls, + file: str, + chr_min_length = 10**6, + window_length: int = 10**6, + batch_size: int = 10 ** 6, + cpu: int = cpu_count() + ): + """ + Initialize ChrLevels with CX_report file + + :param file: Path to file + :param chr_min_length: Minimum length of chromosome to be analyzed + :param window_length: Length of windows in bp + :param cpu: How many cores to use. 
Uses every physical core by default + :param batch_size: Number of rows to read by one CPU core + """ + cpu = cpu if cpu is not None else cpu_count() + + bismark = pl.read_csv_batched( + file, + separator='\t', has_header=False, + new_columns=['chr', 'position', 'strand', + 'count_m', 'count_um', 'context'], + columns=[0, 1, 2, 3, 4, 5], + batch_size=batch_size, + n_threads=cpu + ) + read_approx = approx_batch_num(file, batch_size) + read_batches = 0 + + total = None + + batches = bismark.next_batches(cpu) + print(f"Reading from {file}") + while batches: + for df in batches: + df = ( + df.lazy() + .filter((pl.col('count_m') + pl.col('count_um') != 0)) + .group_by(["strand", "chr"]) + .agg([ + pl.col("context"), + (pl.col("position") / window_length).floor().alias("window").cast(pl.Int32), + ((pl.col('count_m')) / (pl.col('count_m') + pl.col('count_um'))).alias('density').cast(pl.Float32), + (pl.max("position") - pl.min("position")).alias("length") + ]) + .filter(pl.col("length") > chr_min_length) + .explode(["context", "window", "density"]) + .group_by(by=['chr', 'strand', 'context', 'window']) + .agg([ + pl.sum('density').alias('sum'), + pl.count('density').alias('count') + ]) + .drop_nulls(subset=['sum']) + ).collect() + if total is None and len(df) == 0: + raise Exception( + "Error reading Bismark file. Check format or genome. 
def filter(self, context: str = None, strand: str = None, chr: str = None):
    """
    Keep only the rows that match every supplied criterion.

    :param context: Methylation context (CG, CHG, CHH) to filter (only one).
    :param strand: Strand to filter (+ or -).
    :param chr: Chromosome name to filter.
    :return: ``self`` when no criterion is given, otherwise a new instance
        of the same class built from the filtered data.
    """
    requested = {"context": context, "strand": strand, "chr": chr}
    active = {column: value for column, value in requested.items() if value is not None}

    # Nothing to filter on: return the instance itself instead of passing a
    # constant ``True`` predicate to the frame.
    if not active:
        return self

    filtered = self.bismark
    for column, value in active.items():
        filtered = filtered.filter(pl.col(column) == value)

    return self.__class__(filtered)
def remove_extension(path):
    """
    Return *path* with its last extension removed.

    Only a trailing ``.suffix`` of the final path component is stripped;
    a path without such a suffix is returned unchanged.
    """
    return re.sub(r"\.[^./]+$", "", path)


def approx_batch_num(path, batch_size, check_lines=1000):
    """
    Estimate how many batches of ``batch_size`` lines the file contains.

    The mean line length is sampled from at most ``check_lines`` leading
    lines and the file size is divided by the size of one batch.

    :param path: path to the file.
    :param batch_size: number of lines per batch.
    :param check_lines: maximal number of leading lines to sample.
    :return: estimated number of batches (0 for an empty file).
    """
    size = getsize(path)

    sampled_bytes = 0
    sampled_lines = 0
    with open(path, "rb") as file:
        for _ in range(check_lines):
            line = file.readline()
            if not line:
                # file is shorter than the sample; average over what was read
                break
            sampled_bytes += len(line)
            sampled_lines += 1

    if sampled_lines == 0:
        return 0

    mean_line_bytes = sampled_bytes / sampled_lines
    return int(np.ceil(size / (mean_line_bytes * batch_size)))


def hm_flank_lines(axes: "Axes", upstream_windows: int, gene_windows: int, downstream_windows: int):
    """
    Add TSS/TES boundary lines and tick labels to a heat-map axis.

    Positions are shifted by half a window so the lines fall between
    columns: TSS after the upstream windows, TES after the upstream and
    gene-body windows.

    :param axes: axis to draw on.
    :param upstream_windows: number of upstream windows; no TSS line when 0.
    :param gene_windows: number of gene-body windows.
    :param downstream_windows: number of downstream windows; no TES line when 0.
    """
    x_ticks = []
    x_labels = []
    if upstream_windows > 0:
        x_ticks.append(upstream_windows - .5)
        x_labels.append('TSS')
    if downstream_windows > 0:
        # the gene body ends after the upstream AND gene windows
        x_ticks.append(upstream_windows + gene_windows - .5)
        x_labels.append('TES')

    if x_ticks and x_labels:
        axes.set_xticks(x_ticks)
        axes.set_xticklabels(x_labels)
        for tick in x_ticks:
            axes.axvline(x=tick, linestyle='--', color='k', alpha=.3)