Skip to content

Commit

Permalink
feat: coalescing the datasets (opentargets#932)
Browse files Browse the repository at this point in the history
Co-authored-by: Szymon Szyszkowski <ss60@mib117351s.internal.sanger.ac.uk>
  • Loading branch information
project-defiant and Szymon Szyszkowski authored Nov 27, 2024
1 parent 4837a4b commit 7b3bfad
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 18 deletions.
13 changes: 10 additions & 3 deletions src/gentropy/biosample_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Step to generate biosample index dataset."""

from __future__ import annotations

from gentropy.common.session import Session
Expand Down Expand Up @@ -28,10 +29,16 @@ def __init__(
efo_input_path (str): Input efo dataset path.
biosample_index_path (str): Output biosample index dataset path.
"""
cell_ontology_index = extract_ontology_from_json(cell_ontology_input_path, session.spark)
cell_ontology_index = extract_ontology_from_json(
cell_ontology_input_path, session.spark
)
uberon_index = extract_ontology_from_json(uberon_input_path, session.spark)
efo_index = extract_ontology_from_json(efo_input_path, session.spark).retain_rows_with_ancestor_id(["CL_0000000"])
efo_index = extract_ontology_from_json(
efo_input_path, session.spark
).retain_rows_with_ancestor_id(["CL_0000000"])

biosample_index = cell_ontology_index.merge_indices([uberon_index, efo_index])

biosample_index.df.write.mode(session.write_mode).parquet(biosample_index_path)
biosample_index.df.coalesce(session.output_partitions).write.mode(
session.write_mode
).parquet(biosample_index_path)
6 changes: 3 additions & 3 deletions src/gentropy/colocalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def __init__(
coloc = partial(coloc, **colocalisation_method_params)
colocalisation_results = coloc(overlaps)
# Load
colocalisation_results.df.write.mode(session.write_mode).parquet(
f"{coloc_path}/{colocalisation_method.lower()}"
)
colocalisation_results.df.coalesce(session.output_partitions).write.mode(
session.write_mode
).parquet(f"{coloc_path}/{colocalisation_method.lower()}")

@classmethod
def _get_colocalisation_class(
Expand Down
3 changes: 3 additions & 0 deletions src/gentropy/common/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__( # noqa: D107
hail_home: str | None = None,
start_hail: bool = False,
extended_spark_conf: dict[str, str] | None = None,
output_partitions: int = 200,
) -> None:
"""Initialises spark session and logger.
Expand All @@ -34,6 +35,7 @@ def __init__( # noqa: D107
hail_home (str | None): Path to Hail installation. Defaults to None.
start_hail (bool): Whether to start Hail. Defaults to False.
extended_spark_conf (dict[str, str] | None): Extended Spark configuration. Defaults to None.
output_partitions (int): Number of partitions for output datasets. Defaults to 200.
"""
merged_conf = self._create_merged_config(
start_hail, hail_home, extended_spark_conf
Expand All @@ -53,6 +55,7 @@ def __init__( # noqa: D107
self.start_hail = start_hail
if start_hail:
hl.init(sc=self.spark.sparkContext, log="/dev/null")
self.output_partitions = output_partitions

def _default_config(self: Session) -> SparkConf:
"""Default spark configuration.
Expand Down
1 change: 1 addition & 0 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SessionConfig:
spark_uri: str = "local[*]"
hail_home: str = os.path.dirname(hail_location)
extended_spark_conf: dict[str, str] | None = field(default_factory=dict[str, str])
output_partitions: int = 200
_target_: str = "gentropy.common.session.Session"


Expand Down
5 changes: 4 additions & 1 deletion src/gentropy/gene_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Step to generate gene index dataset."""

from __future__ import annotations

from gentropy.common.session import Session
Expand Down Expand Up @@ -28,4 +29,6 @@ def __init__(
# Transform
gene_index = OpenTargetsTarget.as_gene_index(platform_target)
# Load
gene_index.df.write.mode(session.write_mode).parquet(gene_index_path)
gene_index.df.coalesce(session.output_partitions).write.mode(
session.write_mode
).parquet(gene_index_path)
8 changes: 4 additions & 4 deletions src/gentropy/study_locus_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def __init__(

# Valid study locus partitioned to simplify the finding of overlaps
study_locus_with_qc.valid_rows(invalid_qc_reasons).df.repartitionByRange(
"chromosome", "position"
session.output_partitions, "chromosome", "position"
).sortWithinPartitions("chromosome", "position").write.mode(
session.write_mode
).parquet(valid_study_locus_path)

# Invalid study locus
study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode(
session.write_mode
).parquet(invalid_study_locus_path)
study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce(
session.output_partitions
).write.mode(session.write_mode).parquet(invalid_study_locus_path)
12 changes: 6 additions & 6 deletions src/gentropy/study_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def __init__(
) # Flagging QTL studies with invalid biosamples
).persist() # we will need this for 2 types of outputs

study_index_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.mode(
session.write_mode
).parquet(invalid_study_index_path)
study_index_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.coalesce(
session.output_partitions
).write.mode(session.write_mode).parquet(invalid_study_index_path)

study_index_with_qc.valid_rows(invalid_qc_reasons).df.write.mode(
session.write_mode
).parquet(valid_study_index_path)
study_index_with_qc.valid_rows(invalid_qc_reasons).df.coalesce(
session.output_partitions
).write.mode(session.write_mode).parquet(valid_study_index_path)
4 changes: 3 additions & 1 deletion src/gentropy/variant_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def __init__(
variant_index = variant_index.add_annotation(annotations)

(
variant_index.df.repartitionByRange("chromosome", "position")
variant_index.df.repartitionByRange(
session.output_partitions, "chromosome", "position"
)
.sortWithinPartitions("chromosome", "position")
.write.mode(session.write_mode)
.parquet(variant_index_path)
Expand Down

0 comments on commit 7b3bfad

Please sign in to comment.