Skip to content

Commit

Permalink
Fix RNATables in case of very large collections
Browse files Browse the repository at this point in the history
  • Loading branch information
JureZmrzlikar committed Jun 26, 2023
1 parent a286fc6 commit ea13f64
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
9 changes: 9 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ Change Log
All notable changes to this project are documented in this file.


==========
Unreleased
==========

Fixed
-----
- Fix ``RNATables`` in case of very large collections


===================
18.0.0 - 2023-05-18
===================
Expand Down
27 changes: 13 additions & 14 deletions src/resdk/tables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,18 @@ def _data(self) -> List[Data]:
:return: list of Data objects
"""
data = []
sample_ids, repeated_sample_ids = set(), set()
sample2data = {}
repeated_sample_ids = set()
for datum in self.collection.data.filter(
type=self.process_type,
status="OK",
ordering="-created",
fields=self.DATA_FIELDS,
):
if datum.sample.id in sample_ids:
).iterate():
# We are using iterate to prevent 504 Gateway Timeout errors
# This means that data is given from oldest to newest
if datum.sample.id in sample2data:
repeated_sample_ids.add(datum.sample.id)
continue
sample_ids.add(datum.sample.id)
data.append(datum)
sample2data[datum.sample.id] = datum

if repeated_sample_ids:
repeated = ", ".join(map(str, repeated_sample_ids))
Expand All @@ -190,7 +189,7 @@ def _data(self) -> List[Data]:
UserWarning,
)

return data
return list(sample2data.values())

@property
@lru_cache()
Expand Down Expand Up @@ -256,9 +255,8 @@ def _qc_version(self) -> str:
type="data:multiqc",
status="OK",
entity__isnull=False,
ordering="id",
fields=["id", "entity__id"],
)
).iterate()
]
if not mqc_ids:
raise ValueError(
Expand Down Expand Up @@ -533,7 +531,7 @@ async def _download_data(self, data_type: str) -> pd.DataFrame:
:param data_type: data type
:return: table with data, features in columns, samples in rows
"""
samples_data = []
df = None
for i in self.tqdm(
range(0, len(self._data), EXP_ASYNC_CHUNK_SIZE),
desc="Downloading data",
Expand All @@ -559,8 +557,9 @@ async def _download_data(self, data_type: str) -> pd.DataFrame:
for url, id_ in urls_ids
]
data = await asyncio.gather(*futures)
samples_data.extend(data)
data = pd.concat(data, axis=1)
df = pd.concat([df, data], axis=1)

df = pd.concat(samples_data, axis=1).T.sort_index().sort_index(axis=1)
df = df.T.sort_index().sort_index(axis=1)
df.index.name = "sample_id"
return df
2 changes: 2 additions & 0 deletions src/resdk/tables/rna.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,8 @@ def _parse_file(self, file_obj, sample_id, data_type):
sample_data = pd.read_csv(file_obj, sep="\t", compression="gzip")
sample_data = sample_data.set_index("Gene")["Expression"]
sample_data.name = sample_id
# Optimize memory usage, float32 and int32 will suffice.
sample_data = sample_data.astype("int32" if data_type == self.RC else "float32")
return sample_data

async def _download_data(self, data_type: str) -> pd.DataFrame:
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/tables/e2e_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def test_rc(self):
self.assertIn(39000, self.ct.rc.index)
self.assertIn("ENSG00000000003", self.ct.rc.columns)
self.assertEqual(self.ct.rc.iloc[0, 0], 792)
self.assertIsInstance(self.ct.rc.iloc[0, 0], np.int64)
self.assertIsInstance(self.ct.rc.iloc[0, 0], np.int32)

def test_exp(self):
self.assertEqual(self.ct.exp.shape, (8, 58487))
self.assertIn(39000, self.ct.exp.index)
self.assertIn("ENSG00000000003", self.ct.exp.columns)
self.assertAlmostEqual(self.ct.exp.iloc[0, 0], 19.447467, places=3)
self.assertIsInstance(self.ct.exp.iloc[0, 0], np.float64)
self.assertIsInstance(self.ct.exp.iloc[0, 0], np.float32)

def test_consistent_index(self):
self.assertTrue(all(self.ct.exp.index == self.ct.meta.index))
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def setUp(self):
self.collection.slug = "slug"
self.collection.name = "Name"
self.collection.samples.filter = self.web_request([self.sample])
self.collection.data.filter = self.web_request([self.data])
self.collection.data.filter().iterate = self.web_request([self.data])
self.collection.resolwe = self.resolwe

self.relation = MagicMock()
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_heterogeneous_collections(self):
data2.id = 12345
data2.process.slug = "process-slug2"
data2.output.__getitem__.side_effect = {"source": "ENSEMBL"}.__getitem__
self.collection.data.filter = self.web_request([self.data, data2])
self.collection.data.filter().iterate = self.web_request([self.data, data2])

with self.assertRaisesRegex(ValueError, r"Expressions of all samples.*"):
RNATables(self.collection)
Expand All @@ -112,7 +112,7 @@ def test_heterogeneous_collections(self):
data2.id = 12345
data2.process.slug = "process-slug"
data2.output.__getitem__.side_effect = {"source": "GENCODE"}.__getitem__
self.collection.data.filter = self.web_request([self.data, data2])
self.collection.data.filter().iterate = self.web_request([self.data, data2])

with self.assertRaisesRegex(ValueError, r"Alignment of all samples.*"):
RNATables(self.collection)
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_metadata_version(self):
version = ct1._metadata_version

def test_qc_version(self):
self.collection.data.filter = self.web_request([self.data])
self.collection.data.filter().iterate = self.web_request([self.data])

ct = RNATables(self.collection)
version = ct._qc_version
Expand All @@ -256,7 +256,7 @@ def test_qc_version(self):
version = ct._qc_version
self.assertTrue(time() - t < 0.1)

self.collection.data.filter = self.web_request([])
self.collection.data.filter().iterate = self.web_request([])
ct1 = RNATables(self.collection)
with self.assertRaises(ValueError):
version = ct1._qc_version
Expand All @@ -271,7 +271,7 @@ def test_data_version(self):
version = ct._data_version
self.assertTrue(time() - t < 0.1)

self.collection.data.filter = MagicMock(return_value=[])
self.collection.data.filter().iterate = MagicMock(return_value=[])
ct = RNATables(self.collection)
with self.assertRaises(ValueError):
version = ct._data_version
Expand Down

0 comments on commit ea13f64

Please sign in to comment.