-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_parallel.py
75 lines (57 loc) · 2.14 KB
/
run_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Run a function using parallel processing
Blank Project
Image and Sound Processing Lab - Politecnico di Milano
Paolo Bestagini
"""
import time
from functools import partial
from multiprocessing import Pool, cpu_count
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
def fun(x, y):
    """
    Add two values together.

    This is the toy workload that ``main`` evaluates serially, with a process
    pool and with a thread pool.

    :param x: first operand
    :param y: second operand
    :return: the value of ``x + y``
    """
    total = x + y
    return total
def main():
    """
    Evaluate ``fun`` over a range of inputs in three ways and compare them.

    The same workload (``fun(x, y)`` for every x in ``x_list``, with y fixed)
    is run serially, with a ``multiprocessing.Pool`` and with a
    ``ThreadPoolExecutor``. Each run shows a tqdm progress bar. The results
    and elapsed times (in milliseconds) are printed at the end.
    """
    # Input parameters
    x_list = np.arange(10)
    y = 5

    # Number of workers for parallel processing: half of the available cores,
    # but never fewer than one. On a single-core machine cpu_count() // 2 == 0,
    # and both Pool(0) and ThreadPoolExecutor(0) raise ValueError.
    num_cpu = max(1, cpu_count() // 2)

    # We can only loop over one parameter (e.g., x), thus we need to fix the
    # other function parameters (e.g., y=5)
    fun_part = partial(fun, y=y)

    # Initialize the pool of processes before any timing starts, so worker
    # start-up is not charged to the parallel measurement. The context manager
    # shuts the workers down even if something below raises.
    with Pool(num_cpu) as pool:
        # Evaluate the function in series
        t = time.perf_counter()
        result_series = []
        # for x in x_list:  # Without wait-bar
        for x in tqdm(x_list, total=len(x_list), desc='Serial'):  # With wait-bar
            result_series.append(fun_part(x))
        t_series = time.perf_counter() - t

        # Evaluate the function in parallel
        t = time.perf_counter()
        # result_parallel = pool.map(fun_part, x_list)  # Without wait-bar
        result_parallel = list(tqdm(pool.imap(fun_part, x_list), total=len(x_list), desc='Parallel'))  # With wait-bar
        t_parallel = time.perf_counter() - t
    # Leaving the ``with`` block terminates the pool. That is safe here because
    # list(...) above has already collected every result.

    # Evaluate the function with multi-threading
    # BEST TO USE for I/O operations on data (threads share the GIL, so they do
    # not speed up pure-Python CPU-bound work)
    t = time.perf_counter()
    with ThreadPoolExecutor(num_cpu) as executor:
        results_mthread = list(tqdm(executor.map(fun_part, x_list), total=len(x_list), desc='Multi-threading'))  # With wait-bar
    t_mthread = time.perf_counter() - t

    # Print results
    print('Serial results: {} [{:.2f} ms]'.format(result_series, t_series*1000))
    print('Parallel results: {} [{:.2f} ms]'.format(result_parallel, t_parallel*1000))
    print('Multi-threading results: {} [{:.2f} ms]'.format(results_mthread, t_mthread*1000))
if __name__ == "__main__":
    # Run the comparison only when executed as a script. multiprocessing also
    # needs this guard on platforms that start workers by re-importing the
    # main module (the "spawn" start method).
    main()