-
Notifications
You must be signed in to change notification settings - Fork 90
/
IPMatrix.py
176 lines (150 loc) · 5.41 KB
/
IPMatrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#############################################################################
######### class Matrix, plus IPython PNG display hooks for image visualization
#############################################################################
class Matrix:
    """
    Represents a rectangular matrix with n rows and m columns.

    Entries live in ``self.rows``, a list of row lists. Use ``M[i, j]`` for a
    single cell and ``M[i1:i2, j1:j2]`` for a rectangular sub-matrix.
    """

    def __init__(self, n, m, val=0):
        """
        Create an n-by-m matrix of val's.
        Inner representation: list of lists (rows).

        Raises AssertionError unless n > 0 and m > 0.
        """
        assert n > 0 and m > 0
        # Build a fresh list per row. The shortcut [[val]*m]*n would store n
        # references to ONE row list, so assigning a single cell would change
        # that column in every row at once.
        self.rows = [[val] * m for _ in range(n)]

    def dim(self):
        """Return the shape as a tuple (number of rows, number of columns)."""
        return len(self.rows), len(self.rows[0])

    def __eq__(self, other):
        """Return True iff other is a Matrix with identical rows."""
        return isinstance(other, Matrix) and self.rows == other.rows

    # cell/sub-matrix access/assignment
    ####################################
    def __getitem__(self, ij):
        """
        Read a cell or a sub-matrix. ij is a tuple (i, j), which allows
        m[i, j] instead of m[i][j].

        - two ints: return the single entry self.rows[i][j].
        - two slices: return a new Matrix built from sliced copies of the
          selected rows.

        Raises TypeError for any other index combination (e.g. an int mixed
        with a slice).
        """
        i, j = ij
        if isinstance(i, int) and isinstance(j, int):
            return self.rows[i][j]
        if isinstance(i, slice) and isinstance(j, slice):
            M = Matrix(1, 1)  # placeholder; its rows are replaced below
            M.rows = [row[j] for row in self.rows[i]]
            return M
        # NotImplemented is only meaningful for binary operators; returning it
        # here would hand the caller a bogus "value".
        raise TypeError("Matrix indices must be two ints or two slices")

    def __setitem__(self, ij, val):
        """
        Assign a cell or a sub-matrix. ij is a tuple (i, j), which allows
        m[i, j] = v instead of m[i][j] = v.

        - two ints: val must be an int, float or complex number.
        - two slices: val must be a Matrix whose dimensions match the
          selected region; its entries are copied into self.

        Raises AssertionError on a type/shape mismatch and TypeError for any
        other index combination.
        """
        i, j = ij
        if isinstance(i, int) and isinstance(j, int):
            assert isinstance(val, (int, float, complex))
            self.rows[i][j] = val
        elif isinstance(i, slice) and isinstance(j, slice):
            assert isinstance(val, Matrix)
            n, m = val.dim()
            # A new outer list, but it holds self's own row objects, so
            # writing into them below mutates self.
            s_rows = self.rows[i]
            assert len(s_rows) == n and len(s_rows[0][j]) == m
            for s_row, v_row in zip(s_rows, val.rows):
                s_row[j] = v_row  # slice assignment copies the entries
        else:
            # __setitem__'s return value is ignored, so the old
            # `return NotImplemented` silently dropped the assignment.
            raise TypeError("Matrix indices must be two ints or two slices")

    def copy(self):
        """Return a new Matrix with the same entries (no rows are shared)."""
        n, m = self.dim()
        M = Matrix(n, m)
        for i in range(n):
            for j in range(m):
                M[i, j] = self[i, j]
        return M

    # arithmetic operations
    ########################
    def entrywise_op(self, other, op):
        """
        Return a new Matrix whose (i, j) entry is op(self[i, j], other[i, j]).

        Returns NotImplemented when other is not a Matrix, so the operators
        that delegate here let Python fall back to the other operand (and
        ultimately raise TypeError). Raises AssertionError if the dimensions
        differ.
        """
        if not isinstance(other, Matrix):
            return NotImplemented
        assert self.dim() == other.dim()
        n, m = self.dim()
        M = Matrix(n, m)
        for i in range(n):
            for j in range(m):
                M[i, j] = op(self[i, j], other[i, j])
        return M

    def __add__(self, other):
        """Entrywise sum of two equally-sized matrices."""
        return self.entrywise_op(other, lambda x, y: x + y)

    def __sub__(self, other):
        """Entrywise difference of two equally-sized matrices."""
        return self.entrywise_op(other, lambda x, y: x - y)

    def __neg__(self):
        """Return -self, computed as the zero matrix minus self."""
        n, m = self.dim()
        return Matrix(n, m) - self

    def __mul__(self, other):
        """
        Matrix * Matrix -> matrix product; Matrix * number -> every entry
        scaled; any other operand -> NotImplemented.
        """
        if isinstance(other, Matrix):
            return self.multiply_by_matrix(other)
        elif isinstance(other, (int, float, complex)):
            return self.multiply_by_scalar(other)
        else:
            return NotImplemented

    # Lets `number * Matrix` scale too. For Matrix * Matrix the left operand's
    # __mul__ already succeeds, so this is only reached for scalars.
    __rmul__ = __mul__

    def multiply_by_scalar(self, val):
        """Return a new Matrix with every entry multiplied by val."""
        n, m = self.dim()
        M = Matrix(n, m)
        # Computed directly rather than through entrywise_op with a temporary
        # n-by-m matrix filled with val, which doubled the memory needed.
        for i in range(n):
            for j in range(m):
                M[i, j] = self[i, j] * val
        return M

    def multiply_by_matrix(self, other):
        """
        Return the matrix product self @ other.

        Raises AssertionError unless self's column count equals other's row
        count.
        """
        assert isinstance(other, Matrix)
        n, m = self.dim()
        n2, m2 = other.dim()
        assert m == n2
        M = Matrix(n, m2)
        for i in range(n):
            for j in range(m2):
                M[i, j] = sum(self[i, k] * other[k, j] for k in range(m))
        return M

    # Input/output
    ###############
    def save(self, filename):
        """
        Write the matrix to a text file: the first line is "n m", then one
        line per row with each entry followed by a single space.
        """
        with open(filename, 'w') as f:  # closed even if a write fails
            n, m = self.dim()
            print(n, m, file=f)
            for row in self.rows:
                for e in row:
                    print(e, end=" ", file=f)
                print("", file=f)  # newline

    @staticmethod
    def load(filename):
        """
        Read a matrix in the format written by save().

        Only integer entries are parsed; a file saved from a matrix holding
        floats or complex numbers raises ValueError. Raises AssertionError if
        a row does not have the number of entries given in the header.
        """
        with open(filename) as f:  # the old version never closed the file
            n, m = [int(x) for x in f.readline().split()]
            result = Matrix(n, m)
            for i in range(n):
                row = [int(x) for x in f.readline().split()]
                assert len(row) == m
                result.rows[i] = row
        return result
# The code below lets PIL images and Matrix objects be rendered as images in an IPython Notebook
from io import BytesIO

from IPython import get_ipython
from IPython.core import display
from PIL import Image
def display_pil_image(im):
    """Displayhook function for PIL Images, rendered as PNG.

    Encodes im as PNG into an in-memory buffer and passes the raw bytes to
    IPython's display_png.
    """
    b = BytesIO()
    # Save the parameter `im`; the old code referenced `img`, which is not
    # defined in this function (NameError, or a stray global image).
    im.save(b, format='png')
    data = b.getvalue()
    display.display_png(data, raw=True)
def display_matrix_image(mat):
    """Render a Matrix as a grayscale ('L' mode) PNG in IPython.

    Entry mat[r, c] becomes the pixel at x=c, y=r. Values are written with
    putpixel; presumably they are expected to lie in 0..255 for 'L' mode
    (TODO confirm with callers).
    """
    height, width = mat.dim()
    picture = Image.new('L', size=(width, height))
    for r in range(height):
        for c in range(width):
            picture.putpixel((c, r), mat[r, c])
    png_buffer = BytesIO()
    picture.save(png_buffer, format='png')
    display.display_png(png_buffer.getvalue(), raw=True)
# Register the display functions with IPython's PNG formatter so that PIL
# images and Matrix objects are rendered as images in a notebook.
# IPython.get_ipython() returns None outside an IPython session; the bare
# builtin used previously does not exist there at all, so importing this
# module from plain Python raised NameError.
_ipython = get_ipython()
if _ipython is not None:
    png_formatter = _ipython.display_formatter.formatters['image/png']
    dpi = png_formatter.for_type(Image.Image, display_pil_image)
    dpi = png_formatter.for_type(Matrix, display_matrix_image)