-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils_math.py
99 lines (96 loc) · 5.83 KB
/
utils_math.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from numpy import genfromtxt
import numpy as np
def Th_comp_matmul(Ar, Ai, Br, Bi):  # Complex matmul pytorch function ########
    """Multiply two complex tensors given as separate real/imaginary parts.

    Computes ``C = A @ B`` with ``A = Ar + i*Ai`` and ``B = Br + i*Bi`` using
    ``torch.matmul`` semantics, so leading batch dimensions broadcast and the
    two operands may have different ranks (e.g. a batch of matrices times a
    single matrix).

    Args:
        Ar, Ai: Real and imaginary parts of A (same shape).
        Br, Bi: Real and imaginary parts of B (same shape).

    Returns:
        Tuple ``(c_th_r, c_th_i)`` holding the real and imaginary parts of
        ``A @ B``.
    """
    # (Ar + iAi)(Br + iBi) = (Ar Br - Ai Bi) + i (Ar Bi + Ai Br).
    # torch.matmul broadcasts batch dims, so this single expression handles
    # every rank combination the old per-ndim branches did, plus the ones they
    # missed (which used to end in UnboundLocalError). It also needs 4
    # sub-products instead of the 8 implied by multiplying the doubled
    # [[Ar, -Ai], [Ai, Ar]] block matrices and discarding half the result.
    c_th_r = torch.matmul(Ar, Br) - torch.matmul(Ai, Bi)
    c_th_i = torch.matmul(Ar, Bi) + torch.matmul(Ai, Br)
    return c_th_r, c_th_i
def Th_inv(Ar, Ai):  # Complex inverse pytorch function ########
    """Invert a complex square matrix (or batch) given as real/imaginary parts.

    Computes ``X = A^{-1}`` for ``A = Ar + i*Ai``. Leading batch dimensions
    are supported, as with ``torch.inverse``.

    Args:
        Ar, Ai: Real and imaginary parts of A, shape ``(..., n, n)``.

    Returns:
        Tuple ``(Ar_inv, Ai_inv)`` with the real and imaginary parts of A^{-1}.

    Raises:
        Whatever ``torch.inverse`` raises for a singular input (a
        RuntimeError subclass in recent PyTorch versions) when A is singular.
    """
    n = Ar.shape[-1]
    # Real representation of A: M = [[Ar, -Ai], [Ai, Ar]] (2n x 2n).
    # M is invertible exactly when A is (det M = |det A|^2), and
    # M^{-1} = [[Xr, -Xi], [Xi, Xr]] with X = A^{-1}. Unlike the
    # Schur-complement formula (Ar + Ai Ar^{-1} Ai)^{-1}, this does not
    # require Ar itself to be invertible (e.g. purely imaginary A), and it
    # avoids amplifying error when Ar is ill-conditioned.
    m_th = torch.cat(
        (torch.cat((Ar, -Ai), dim=-1), torch.cat((Ai, Ar), dim=-1)), dim=-2
    )
    m_inv = torch.inverse(m_th)
    # Left block column of M^{-1} holds [Xr; Xi].
    Ar_inv = m_inv[..., :n, :n]
    Ai_inv = m_inv[..., n:, :n]
    return Ar_inv, Ai_inv
def Th_pinv(Ar, Ai):  # Complex pseudo-inverse pytorch function ########
    """Pseudo-inverse of a full-rank complex matrix (or batch of matrices).

    ``A = Ar + i*Ai`` has shape ``(..., m, n)``; any number of leading batch
    dimensions is accepted. With ``A^H`` the conjugate transpose:

    * wide (m < n):  ``A^+ = A^H (A A^H)^{-1}``
    * tall (m > n):  ``A^+ = (A^H A)^{-1} A^H``
    * square:        ``A^+ = A^{-1}``

    These closed forms assume A has full rank, so that the Gram matrix being
    inverted is non-singular.

    Args:
        Ar, Ai: Real and imaginary parts of A, same shape, ndim >= 2.

    Returns:
        Tuple ``(real, imag)`` of the pseudo-inverse, shape ``(..., n, m)``.

    Raises:
        ValueError: if the inputs have fewer than two dimensions. Previously
            this case, and any rank above 4, silently returned ``None``.
    """
    if Ar.ndim < 2:
        raise ValueError(
            f"Th_pinv expects tensors with ndim >= 2, got ndim={Ar.ndim}"
        )
    rows, cols = Ar.shape[-2], Ar.shape[-1]
    # Conjugate transpose over the last two dims: A^H = Ar^T - i Ai^T.
    # transpose(-2, -1) equals .T for 2-D input and permute(0, 2, 1) /
    # permute(0, 1, 3, 2) for 3-D / 4-D input, so one code path covers them all.
    AHr = Ar.transpose(-2, -1)
    AHi = -Ai.transpose(-2, -1)
    if rows < cols:
        # Full row rank: A A^H is (m x m) and invertible.
        Tempr, Tempi = Th_comp_matmul(Ar, Ai, AHr, AHi)
        Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
        return Th_comp_matmul(AHr, AHi, Ar_inv, Ai_inv)
    if rows > cols:
        # Full column rank: A^H A is (n x n) and invertible.
        Tempr, Tempi = Th_comp_matmul(AHr, AHi, Ar, Ai)
        Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
        return Th_comp_matmul(Ar_inv, Ai_inv, AHr, AHi)
    # Square: the pseudo-inverse is the ordinary inverse.
    return Th_inv(Ar, Ai)