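"""Pre-processing pipeline for taxonomic classification of DNA reads.

Starting from the FASTA files in ``data/raw``, the script:

1. queries NCBI Entrez to map every sequence ID to its taxonomy ID
   (cached in ``data/raw/dataset.pickle``);
2. resolves the full taxonomic lineage of each taxonomy ID with
   pytaxonkit (cached in ``data/raw/dataset.csv``);
3. groups the sequences by the requested taxonomy level, shreds them into
   fixed-length reads with ``gt shredder`` and, for the training split only,
   removes redundant reads with ``cd-hit-est``;
4. writes the shuffled train/validation/test splits to
   ``data/raw/<taxonomy_level>_{train,val,test}.csv``.

Usage::

    python pre_processing.py <entrez_email>

where ``<entrez_email>`` is the e-mail address passed to NCBI Entrez.
"""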
import logging
import os
import pickle
import random
import sys
import time
from functools import partial
from glob import glob
from multiprocessing import Pool
from typing import Dict, Final, List
from urllib.error import HTTPError, URLError

import numpy as np
import pandas as pd
import pytaxonkit
import xmltodict
from Bio import Entrez, SeqIO
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm

from utils import SEPARATOR, setup_logger

DATASET_PATH: Final = os.path.join(os.getcwd(), 'data', 'raw')
DICT_DATASET_PATH: Final = os.path.join(DATASET_PATH, 'dataset.pickle')
DF_DATASET_PATH: Final = os.path.join(DATASET_PATH, 'dataset.csv')


def split_fasta_file_on_processes(fasta_files: List[str], n_proc: int) -> List[List[str]]:
    """Split the FASTA file paths into n_proc nearly equal chunks (e.g. 5 files over 2 processes -> 3 and 2)."""
    n_files: int = len(fasta_files)
    n_files_for_process: int = n_files // n_proc
    rest: int = n_files % n_proc
    fasta_files_for_each_process: List[List[str]] = []
    rest_added: int = 0
    for i in range(n_proc):
        start: int = i * n_files_for_process + rest_added
        # the first `rest` processes receive one extra file each
        if rest > i:
            end: int = start + n_files_for_process + 1
            fasta_files_for_each_process.append(fasta_files[start:end])
            rest_added += 1
        else:
            end = start + n_files_for_process
            fasta_files_for_each_process.append(fasta_files[start:end])
    return fasta_files_for_each_process


def split_dataset_on_processes(
        fasta_files_for_each_process: List[List[str]],
        dataset: Dict[str, Dict[str, str]]) -> List[Dict[str, Dict[str, str]]]:
    """Partition the per-file id -> tax_id mappings according to the per-process file split."""
    # init list of n dicts for n processes
    dataset_for_each_process: List[Dict[str, Dict[str, str]]] = []
    # for each group of files assigned to process i
    for fasta_files_path in fasta_files_for_each_process:
        dataset_for_process_i: Dict[str, Dict[str, str]] = {}
        # for each file
        for fasta_file_path in fasta_files_path:
            # delegate the file's mapping to process i
            dataset_for_process_i[fasta_file_path] = dataset[fasta_file_path]
        dataset_for_each_process.append(dataset_for_process_i)
    return dataset_for_each_process


def extract_tax_id(fasta_files_path: List[str], logger: logging.Logger) -> Dict[str, Dict[str, str]]:
    """Map every sequence ID in the given FASTA files to its NCBI taxonomy ID via Entrez."""
    # init dataset
    dataset: Dict[str, Dict[str, str]] = {}
    # for each fasta file
    for fasta_file_path in fasta_files_path:
        # map each sequence id to its tax id
        mapping: Dict[str, str] = {}
        # read fasta file
        fasta_file = SeqIO.parse(open(fasta_file_path), 'fasta')
        # for each read in file
        for record in fasta_file:
            logger.info(f'Request for id: {record.id}')
            status_code: int = 0
            handle = None
            while status_code != 200:
                try:
                    # fetch the record from NCBI to extract its tax id
                    handle = Entrez.efetch(
                        db="nuccore",
                        id=record.id,
                        retmode="xml",
                        rettype="fasta"
                    )
                    status_code = handle.getcode()
                except HTTPError:
                    # transient server error: wait briefly before retrying
                    time.sleep(1)
                except URLError:
                    # network error: give up on this record
                    status_code = 200
                    logger.info(f'url error: skip record {record.id} of {fasta_file_path}')
            if handle is None:
                continue
            response_dict = xmltodict.parse(handle)
            tax_id = response_dict['TSeqSet']['TSeq']['TSeq_taxid']
            # save id -> tax_id mapping
            mapping[record.id] = tax_id
            handle.close()
        # save all results of this file in the dataset
        dataset[fasta_file_path] = mapping
        # log file completed
        logger.info(f'{fasta_file_path} computed')
    return dataset


def extract_lineage(dataset: Dict[str, Dict[str, str]], logger: logging.Logger) -> pd.DataFrame:
    """Resolve the full taxonomic lineage of every tax ID with pytaxonkit and return it as a dataframe."""
    # init local dataframe
    df = pd.DataFrame()
    for fasta_file_path in dataset.keys():
        for seq_id in dataset[fasta_file_path].keys():
            # get tax_id
            tax_id = dataset[fasta_file_path][seq_id]
            # get lineage by tax_id
            df_result = pytaxonkit.lineage([tax_id])
            try:
                # get the lineage values and their ranks
                full_lineage_ranks: List[str] = df_result['FullLineageRanks'][0].split(';')
                full_lineage: List[str] = df_result['FullLineage'][0].split(';')
                # create row of result
                values = {'File Path': fasta_file_path, 'ID': seq_id, 'TaxID': tax_id}
                values.update(dict(zip(full_lineage_ranks, full_lineage)))
                # merge result into the local dataframe
                df = pd.concat([df, pd.DataFrame([values])])
            except AttributeError:
                # no lineage found (NaN value): skip this sequence
                continue
        # log file completed
        logger.info(f'{fasta_file_path} computed')
    return df
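

# NOTE: the next step relies on two external command-line tools that are
# assumed to be available on the PATH: `gt shredder` (GenomeTools), used to
# cut the grouped sequences into fixed-length reads, and `cd-hit-est`
# (CD-HIT), used to cluster near-identical reads when under-sampling the
# training set.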
def generate_dataframe(len_read: int,
                       len_overlap: int,
                       dataset: pd.DataFrame,
                       dataset_status: pd.Series,
                       taxonomy_level: str,
                       indexes: List[int],
                       cluster: bool = False) -> pd.DataFrame:
    """Group the selected sequences by taxonomy level, shred them into reads and merge them into a dataframe."""
    # group the selected sequences into one fasta file per taxonomy value
    for index in tqdm(indexes,
                      total=len(indexes),
                      desc='Grouping all sequences by taxonomy level'):
        entry = dataset.loc[index]
        id_sequence: str = entry['ID']
        taxonomy_value: str = entry[taxonomy_level]
        fasta_file = SeqIO.parse(open(entry['File Path']), 'fasta')
        for sequence_fasta in fasta_file:
            if sequence_fasta.id == id_sequence:
                record = SeqRecord(
                    sequence_fasta.seq,
                    id=f'{id_sequence}_{taxonomy_value}',
                )
                with open(os.path.join(DATASET_PATH, f'{taxonomy_value}.fasta'), 'a') as output_handle:
                    SeqIO.write(record, output_handle, 'fasta')
                break
    # shred each grouped fasta file into fixed-length reads
    for taxonomy_value in tqdm(dataset_status.keys(),
                               total=len(dataset_status.keys()),
                               desc=f'Generating {len_read} bp reads with overlap size {len_overlap}...'):
        tmp_fasta_path: str = os.path.join(DATASET_PATH, f'{taxonomy_value}.fasta')
        gt_fasta_path: str = os.path.join(DATASET_PATH, f'{taxonomy_value}_gt.fasta')
        if len_overlap > 0:
            command: str = f'gt shredder ' \
                           f'-minlength {len_read} ' \
                           f'-maxlength {len_read} ' \
                           f'-overlap {len_overlap} ' \
                           f'-clipdesc yes ' \
                           f'{tmp_fasta_path} >> {gt_fasta_path}'
        else:
            command = f'gt shredder ' \
                      f'-minlength {len_read} ' \
                      f'-maxlength {len_read} ' \
                      f'-clipdesc yes ' \
                      f'{tmp_fasta_path} >> {gt_fasta_path}'
        os.system(command)
        # remove the index files generated by gt shredder
        for file_ext in ['', '.sds', '.ois', '.md5', '.esq', '.des', '.ssp']:
            os.system(f'rm {tmp_fasta_path}{file_ext}')
    # apply under-sampling with clustering and merge all reads into a single dataframe
    cluster_path: str = os.path.join(DATASET_PATH, 'cluster')
    columns: Final = ['id_gene', taxonomy_level, 'start', 'end', 'sequence']
    df: pd.DataFrame = pd.DataFrame(columns=columns)
    for taxonomy_value in tqdm(dataset_status.keys(),
                               total=len(dataset_status.keys()),
                               desc='Merging all reads into a single dataset...'):
        gt_fasta_path = os.path.join(DATASET_PATH, f'{taxonomy_value}_gt.fasta')
        if cluster:
            # cluster near-identical reads with cd-hit-est and keep only the representatives
            os.system(f'cd-hit-est '
                      f'-T 0 '
                      f'-d 0 '
                      f'-i {gt_fasta_path} '
                      f'-o {cluster_path} '
                      f'> /dev/null')
            os.system(f'rm {gt_fasta_path}')
            fasta_file = SeqIO.parse(open(cluster_path), 'fasta')
        else:
            fasta_file = SeqIO.parse(open(gt_fasta_path), 'fasta')
        for sequence_fasta in fasta_file:
            # the shredded read id is expected to look like <accession>_<taxonomy value>_<start>_<length>
            values: List[str] = sequence_fasta.id.split('_')
            id_gene: str = f'{values[0]}_{values[1]}'
            taxonomy_value = values[2]
            start: int = int(values[3])
            end: int = start + int(values[4])
            sequence = sequence_fasta.seq
            row_df: pd.DataFrame = pd.DataFrame(
                [[
                    id_gene,
                    taxonomy_value,
                    start,
                    end,
                    sequence
                ]],
                columns=columns
            )
            df = pd.concat([df, row_df])
        if cluster:
            os.system(f'rm {cluster_path} {cluster_path}.clstr')
        else:
            os.system(f'rm {gt_fasta_path}')
    # return the shuffled dataframe
    return df.sample(frac=1)


def pre_processing(
        taxonomy_level: str,
        len_read: int = 250,
        len_overlap: int = 200,
        train_size: float = 0.6):
    """Run the whole pre-processing pipeline and write the train/val/test CSV files for the given taxonomy level."""
    # setup logger
    logger = setup_logger('logger', os.path.join(os.getcwd(), 'data/logger.log'))
    # get all fasta files path
    fasta_files: List[str] = glob(os.path.join(DATASET_PATH, '*.fasta'))
    # global dict of dataset
    dataset: Dict[str, Dict[str, str]] = {}
    # if the tax id phase is already completed, skip it
    if os.path.exists(DICT_DATASET_PATH):
        # load dataset with pickle
        with open(DICT_DATASET_PATH, 'rb') as handle:
            dataset = pickle.load(handle)
    else:
        # log start phase
        logger.info('Extract tax id phase')
        logger.info(SEPARATOR)
        # config Entrez email (read from the first command line argument)
        Entrez.email = sys.argv[1]
        # split fasta files on cpus
        fasta_files_for_each_process: List[List[str]] = split_fasta_file_on_processes(fasta_files, os.cpu_count())
        # create a process pool that uses all cpus
        start = time.time()
        with Pool(os.cpu_count()) as pool:
            results = pool.imap(partial(extract_tax_id, logger=logger), fasta_files_for_each_process)
            for local_dataset in results:
                dataset.update(local_dataset)
        # log finish phase
        logger.info(f'\nPhase completed in {time.time() - start:.2f} s')
        logger.info(SEPARATOR)
        # save dataset with pickle
        with open(DICT_DATASET_PATH, 'wb') as handle:
            pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
    if not os.path.exists(DF_DATASET_PATH):
        # init number of processes
        n_proc: int = 5
        # log start phase
        logger.info('Extract taxonomy levels phase')
        logger.info(SEPARATOR)
        start = time.time()
        # split fasta files and the dataset across the processes
        fasta_files_for_each_process = split_fasta_file_on_processes(fasta_files, n_proc)
        dataset_for_each_process: List[Dict[str, Dict[str, str]]] = split_dataset_on_processes(
            fasta_files_for_each_process,
            dataset
        )
        # create global df and split work on workers
        dataset_csv = pd.DataFrame()
        with Pool(n_proc) as pool:
            results = pool.imap(partial(extract_lineage, logger=logger), dataset_for_each_process)
            # merge each local dataset
            for local_dataset in results:
                dataset_csv = pd.concat([dataset_csv, local_dataset])
        # save global df
        dataset_csv.to_csv(DF_DATASET_PATH, index=False)
        # log finish phase
        logger.info(f'\nPhase completed in {time.time() - start:.2f} s')
        logger.info(SEPARATOR)
    # load the lineage dataframe
    dataset_csv = pd.read_csv(DF_DATASET_PATH)
    # keep only the rows with a value for the requested taxonomy level
    dataset_csv = dataset_csv[dataset_csv[taxonomy_level].notnull()]
    # count the sequences available for each taxonomy value
    dataset_status = dataset_csv.groupby(taxonomy_level)[taxonomy_level].count()
    # split indexes into train, val and test sets
    train_idx = []
    val_idx = []
    test_idx = []
    for label in dataset_status.keys():
        idx = dataset_csv.index[dataset_csv[taxonomy_level] == label].values
        random.shuffle(idx)
        len_train = int(len(idx) * train_size)
        len_val = int((len(idx) - len_train) / 2)
        # split idx
        train_idx = np.append(train_idx, idx[:len_train])
        val_idx = np.append(val_idx, idx[len_train:len_train + len_val])
        test_idx = np.append(test_idx, idx[len_train + len_val:])
    # np.append returns float arrays: cast the indexes back to integers
    train_idx = train_idx.astype(int)
    val_idx = val_idx.astype(int)
    test_idx = test_idx.astype(int)
    # shuffle train, val and test
    random.shuffle(train_idx)
    random.shuffle(val_idx)
    random.shuffle(test_idx)
    # generate train dataframe
    logger.info('Generate training set...')
    train_df: pd.DataFrame = generate_dataframe(
        len_read=len_read,
        len_overlap=len_overlap,
        dataset=dataset_csv,
        dataset_status=dataset_status,
        taxonomy_level=taxonomy_level,
        indexes=train_idx,
        cluster=True
    )
    train_dataset_path = os.path.join(DATASET_PATH, f'{taxonomy_level}_train.csv')
    train_df.to_csv(train_dataset_path, index=False)
    logger.info(f'{train_dataset_path} generated!')
    # generate val set
    logger.info('Generate validation set...')
    val_df: pd.DataFrame = generate_dataframe(
        len_read=len_read,
        len_overlap=0,
        dataset=dataset_csv,
        dataset_status=dataset_status,
        taxonomy_level=taxonomy_level,
        indexes=val_idx,
        cluster=False
    )
    val_dataset_path = os.path.join(DATASET_PATH, f'{taxonomy_level}_val.csv')
    val_df.to_csv(val_dataset_path, index=False)
    logger.info(f'{val_dataset_path} generated!')
    # generate test set
    logger.info('Generate testing set...')
    test_df: pd.DataFrame = generate_dataframe(
        len_read=len_read,
        len_overlap=0,
        dataset=dataset_csv,
        dataset_status=dataset_status,
        taxonomy_level=taxonomy_level,
        indexes=test_idx,
        cluster=False
    )
    test_dataset_path = os.path.join(DATASET_PATH, f'{taxonomy_level}_test.csv')
    test_df.to_csv(test_dataset_path, index=False)
    logger.info(f'{test_dataset_path} generated!')


if __name__ == '__main__':
    pre_processing(taxonomy_level='order')
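
# Example invocation, assuming the input FASTA files are already in data/raw
# and 'name@example.org' is the e-mail address to register with NCBI Entrez:
#
#   python pre_processing.py name@example.org
#
# To pre-process at a different taxonomy level or read length, call
# pre_processing() directly, e.g. pre_processing(taxonomy_level='family', len_read=150).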