diff --git a/src/gentropy/biosample_index.py b/src/gentropy/biosample_index.py index e0b5e9b10..a6e8b5223 100644 --- a/src/gentropy/biosample_index.py +++ b/src/gentropy/biosample_index.py @@ -1,4 +1,5 @@ """Step to generate biosample index dataset.""" + from __future__ import annotations from gentropy.common.session import Session @@ -28,10 +29,16 @@ def __init__( efo_input_path (str): Input efo dataset path. biosample_index_path (str): Output gene index dataset path. """ - cell_ontology_index = extract_ontology_from_json(cell_ontology_input_path, session.spark) + cell_ontology_index = extract_ontology_from_json( + cell_ontology_input_path, session.spark + ) uberon_index = extract_ontology_from_json(uberon_input_path, session.spark) - efo_index = extract_ontology_from_json(efo_input_path, session.spark).retain_rows_with_ancestor_id(["CL_0000000"]) + efo_index = extract_ontology_from_json( + efo_input_path, session.spark + ).retain_rows_with_ancestor_id(["CL_0000000"]) biosample_index = cell_ontology_index.merge_indices([uberon_index, efo_index]) - biosample_index.df.write.mode(session.write_mode).parquet(biosample_index_path) + biosample_index.df.coalesce(session.output_partitions).write.mode( + session.write_mode + ).parquet(biosample_index_path) diff --git a/src/gentropy/colocalisation.py b/src/gentropy/colocalisation.py index 9682a8ed9..6a4568397 100644 --- a/src/gentropy/colocalisation.py +++ b/src/gentropy/colocalisation.py @@ -70,9 +70,9 @@ def __init__( coloc = partial(coloc, **colocalisation_method_params) colocalisation_results = coloc(overlaps) # Load - colocalisation_results.df.write.mode(session.write_mode).parquet( - f"{coloc_path}/{colocalisation_method.lower()}" - ) + colocalisation_results.df.coalesce(session.output_partitions).write.mode( + session.write_mode + ).parquet(f"{coloc_path}/{colocalisation_method.lower()}") @classmethod def _get_colocalisation_class( diff --git a/src/gentropy/common/session.py b/src/gentropy/common/session.py index 297903629..3a8ad4af7 100644 --- a/src/gentropy/common/session.py +++ b/src/gentropy/common/session.py @@ -24,6 +24,7 @@ def __init__( # noqa: D107 hail_home: str | None = None, start_hail: bool = False, extended_spark_conf: dict[str, str] | None = None, + output_partitions: int = 200, ) -> None: """Initialises spark session and logger. @@ -34,6 +35,7 @@ def __init__( # noqa: D107 hail_home (str | None): Path to Hail installation. Defaults to None. start_hail (bool): Whether to start Hail. Defaults to False. extended_spark_conf (dict[str, str] | None): Extended Spark configuration. Defaults to None. + output_partitions (int): Number of partitions for output datasets. Defaults to 200. """ merged_conf = self._create_merged_config( start_hail, hail_home, extended_spark_conf @@ -53,6 +55,7 @@ def __init__( # noqa: D107 self.start_hail = start_hail if start_hail: hl.init(sc=self.spark.sparkContext, log="/dev/null") + self.output_partitions = output_partitions def _default_config(self: Session) -> SparkConf: """Default spark configuration. diff --git a/src/gentropy/config.py b/src/gentropy/config.py index b32647acf..65fdb5897 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -18,6 +18,7 @@ class SessionConfig: spark_uri: str = "local[*]" hail_home: str = os.path.dirname(hail_location) extended_spark_conf: dict[str, str] | None = field(default_factory=dict[str, str]) + output_partitions: int = 200 _target_: str = "gentropy.common.session.Session" diff --git a/src/gentropy/gene_index.py b/src/gentropy/gene_index.py index ad8e95083..0a317d077 100644 --- a/src/gentropy/gene_index.py +++ b/src/gentropy/gene_index.py @@ -1,4 +1,5 @@ """Step to generate gene index dataset.""" + from __future__ import annotations from gentropy.common.session import Session @@ -28,4 +29,6 @@ def __init__( # Transform gene_index = OpenTargetsTarget.as_gene_index(platform_target) # Load - gene_index.df.write.mode(session.write_mode).parquet(gene_index_path) + gene_index.df.coalesce(session.output_partitions).write.mode( + session.write_mode + ).parquet(gene_index_path) diff --git a/src/gentropy/study_locus_validation.py b/src/gentropy/study_locus_validation.py index cf36a4389..ce2201f80 100644 --- a/src/gentropy/study_locus_validation.py +++ b/src/gentropy/study_locus_validation.py @@ -58,12 +58,12 @@ def __init__( # Valid study locus partitioned to simplify the finding of overlaps study_locus_with_qc.valid_rows(invalid_qc_reasons).df.repartitionByRange( - "chromosome", "position" + session.output_partitions, "chromosome", "position" ).sortWithinPartitions("chromosome", "position").write.mode( session.write_mode ).parquet(valid_study_locus_path) # Invalid study locus - study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode( - session.write_mode - ).parquet(invalid_study_locus_path) + study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce( + session.output_partitions + ).write.mode(session.write_mode).parquet(invalid_study_locus_path) diff --git a/src/gentropy/study_validation.py b/src/gentropy/study_validation.py index 08f601f1e..3d2fdd060 100644 --- a/src/gentropy/study_validation.py +++ b/src/gentropy/study_validation.py @@ -71,10 +71,10 @@ def __init__( ) # Flagging QTL studies with invalid biosamples ).persist() # we will need this for 2 types of outputs - study_index_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode( - session.write_mode - ).parquet(invalid_study_index_path) + study_index_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce( + session.output_partitions + ).write.mode(session.write_mode).parquet(invalid_study_index_path) - study_index_with_qc.valid_rows(invalid_qc_reasons).df.write.mode( - session.write_mode - ).parquet(valid_study_index_path) + study_index_with_qc.valid_rows(invalid_qc_reasons).df.coalesce( + session.output_partitions + ).write.mode(session.write_mode).parquet(valid_study_index_path) diff --git a/src/gentropy/variant_index.py b/src/gentropy/variant_index.py index 9eac684b2..ae7efa5c4 100644 --- a/src/gentropy/variant_index.py +++ b/src/gentropy/variant_index.py @@ -56,7 +56,9 @@ def __init__( variant_index = variant_index.add_annotation(annotations) ( - variant_index.df.repartitionByRange("chromosome", "position") + variant_index.df.repartitionByRange( + session.output_partitions, "chromosome", "position" + ) .sortWithinPartitions("chromosome", "position") .write.mode(session.write_mode) .parquet(variant_index_path)