tensor_functions.py
"""
This module provides additional tensor functions that are not yet implemented in
the RoughPy core library.
The implementation of the functions in this module are far from optimal, but
should serve as both a temporary implementation and a demonstration of how one
can build on top of the RoughPy core library.
"""
import functools
from collections import defaultdict
from typing import Any, Union, TypeVar

import roughpy as rp


def tensor_word_factory(basis):
    """
    Create a factory function that constructs tensor word objects.

    Since tensor words are specific to their basis, the basis is needed to
    construct the words. This function creates a factory bound to the given
    basis so that the words it produces belong to that basis. The arguments
    to the factory function are a sequence of letters, in the same order as
    they will appear in the tensor word.

    :param basis: RoughPy tensor basis object.
    :return: function from a sequence of letters to tensor words
    """
    width = basis.width
    depth = basis.depth

    # noinspection PyArgumentList
    def factory(*args):
        return rp.TensorKey(*args, width=width, depth=depth)

    return factory
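# Illustrative sketch (not executed): sticking to the API already used in this
# file, a factory can be built from the basis of any existing TensorKey and
# then used to assemble new words letter by letter, e.g.
#
#     factory = tensor_word_factory(word.basis())
#     factory(1, 2)   # the word "12" over the same width/depth
#     factory()       # the empty word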
class TensorTensorProduct:
    """
    External tensor product of two free tensors (or shuffle tensors).

    This is an intermediate container that is used to implement some of the
    tensor functions such as Log.
    """

    data: 'dict[tuple[rp.TensorKey, rp.TensorKey], Any]'
    ctx: 'rp.Context'

    def __init__(self, data, ctx=None):
        if isinstance(data, (tuple, list)):
            assert len(data) == 2
            assert isinstance(data[0], (rp.FreeTensor, rp.ShuffleTensor))
            assert isinstance(data[1], (rp.FreeTensor, rp.ShuffleTensor))
            if ctx is not None:
                self.ctx = ctx
                assert data[0].context == ctx
                assert data[1].context == ctx
            else:
                self.ctx = data[0].context
                assert self.ctx == data[1].context
            self.data = odata = {}
            for lhs in data[0]:
                for rhs in data[1]:
                    odata[(lhs.key(), rhs.key())] = lhs.value() * rhs.value()
            # self.data = {k: v for k, v in odata.items() if v != 0}
        elif isinstance(data, (dict, defaultdict)):
            self.data = {k: v for k, v in data.items() if v != 0}
            assert ctx is not None
            self.ctx = ctx
    def __str__(self):
        return " ".join([
            '{',
            *(f"{v}{k}" for k, v in sorted(
                self.data.items(),
                key=lambda t: tuple(map(lambda r: tuple(r.to_letters()), t[0]))
            )),
            '}'
        ])

    def __repr__(self):
        return str(self)
    def __mul__(self, scalar):
        if isinstance(scalar, (int, float, rp.Scalar)):
            if scalar == 0:
                data = {}
            else:
                data = {k: v * scalar for k, v in self.data.items()}
            return TensorTensorProduct(data, self.ctx)
        return NotImplemented

    def __add__(self, other):
        if isinstance(other, TensorTensorProduct):
            assert self.ctx == other.ctx
            new_data = defaultdict(lambda: 0)
            # Do this to force a deep copy of the values from self
            for k, v in self.data.items():
                new_data[k] += v
            for k, v in other.data.items():
                new_data[k] += v
            return TensorTensorProduct(
                {k: v for k, v in new_data.items() if v != 0}, self.ctx)
        return NotImplemented
    def add_scal_prod(self, other, scalar):
        my_data = self.data
        for k, v in other.data.items():
            val = v * scalar
            if k in my_data:
                my_data[k] += val
            else:
                my_data[k] = val
        # print("asp", other, scalar, self)
        self.data = {k: v for k, v in my_data.items() if v != 0}
        return self

    def add_scal_div(self, other, scalar):
        my_data = self.data
        for k, v in other.data.items():
            if k in my_data:
                my_data[k] += v / scalar
            else:
                my_data[k] = v / scalar
        return self

    def sub_scal_div(self, other, scalar):
        my_data = self.data
        for k, v in other.data.items():
            if k in my_data:
                my_data[k] -= v / scalar
            else:
                my_data[k] = -v / scalar
        return self
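# Worked example (for orientation, not executed): if a = 1 + 2*(1) and b = 3*(2)
# are tensors over the same context, then TensorTensorProduct((a, b)) stores
#     {((), (2)): 3, ((1), (2)): 6},
# i.e. the coefficient attached to a pair of words is the product of the
# corresponding coefficients of a and b.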
def _concat_product(a, b):
    # Product on external tensor products induced componentwise by the
    # concatenation of words: (u (x) v) * (w (x) z) = (uw) (x) (vz).
    out = defaultdict(lambda: 0)
    for k1, v1 in a.data.items():
        for k2, v2 in b.data.items():
            out[tuple(i * j for i, j in zip(k1, k2))] += v1 * v2
    return TensorTensorProduct({k: v for k, v in out.items() if v != 0}, a.ctx)
# noinspection PyUnresolvedReferences
def _adjoint_of_word(word: rp.TensorKey, ctx: rp.Context) \
        -> TensorTensorProduct:
    word_factory = tensor_word_factory(word.basis())
    letters = word.to_letters()
    if not letters:
        return TensorTensorProduct({(word_factory(),) * 2: 1}, ctx)
    letters_adj = [
        TensorTensorProduct({(word_factory(letter), word_factory()): 1,
                             (word_factory(), word_factory(letter)): 1}, ctx)
        for letter in letters]
    return functools.reduce(_concat_product, letters_adj)
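# For illustration: for the word 12 the reduction above produces
#     () (x) 12  +  1 (x) 2  +  2 (x) 1  +  12 (x) (),
# i.e. the word is split in every way into a pair of complementary subwords.
# Extended linearly below, this is the coproduct dual to the shuffle product.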
def _adjoint_of_shuffle(
        tensor: Union[rp.FreeTensor, rp.ShuffleTensor]) -> TensorTensorProduct:
    # noinspection PyUnresolvedReferences
    ctx = tensor.context
    out = TensorTensorProduct(defaultdict(lambda: 0), ctx)
    for item in tensor:
        out.add_scal_prod(_adjoint_of_word(item.key(), ctx), item.value())
    return out
def _concatenate(a: TensorTensorProduct, otype=rp.FreeTensor):
    """
    Perform the elementwise reduction induced on A \\otimes B by the
    concatenation of words.

    :param a: External tensor product of tensors
    :param otype: tensor type of the result (FreeTensor by default)
    :return: tensor obtained by reducing all pairs of words
    """
    data = defaultdict(lambda: 0)
    for (l, r), v in a.data.items():
        data[l * r] += v
    # noinspection PyArgumentList
    result = otype(data, ctx=a.ctx)
    return result
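# For example, reducing 1 (x) 2 + 2 (x) 1 with the default otype yields the
# free tensor 12 + 21.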
def _tensor_product_functions(f, g):
    # noinspection PyArgumentList
    def function_product(x: TensorTensorProduct) -> TensorTensorProduct:
        ctx = x.ctx
        result = TensorTensorProduct({}, ctx)
        for (k1, k2), v in x.data.items():
            tk1 = f(rp.FreeTensor((k1, 1), ctx=ctx))
            tk2 = g(rp.FreeTensor((k2, 1), ctx=ctx))
            result.add_scal_prod(TensorTensorProduct((tk1, tk2), ctx), v)
        return result

    return function_product
def _convolve(f, g):
    func = _tensor_product_functions(f, g)

    def convolved(x):
        return _concatenate(func(_adjoint_of_shuffle(x)), otype=type(x))

    return convolved
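# In other words, _convolve(f, g) is the convolution product of linear maps on
# the tensor algebra: (f * g)(x) = concatenate((f (x) g)(delta(x))), where
# delta is the coproduct computed by _adjoint_of_shuffle.  Log below applies
# the usual series J - (J * J)/2 + (J * J * J)/3 - ..., where J is the
# projection _remove_constant that kills the constant term.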
def _remove_constant(x):
    # Project out the constant (empty-word) term of x.
    ctx = x.context
    # noinspection PyArgumentList
    empty_word = rp.TensorKey(width=ctx.width, depth=ctx.depth)
    remover = type(x)((empty_word, x[empty_word]), ctx=ctx)
    return x - remover
Tensor = TypeVar('Tensor')


# noinspection PyPep8Naming
def Log(x: Tensor) -> Tensor:
    """
    Linear function on tensors that agrees with log on the group-like elements.

    This function is the linear extension of the log function defined on the
    group-like elements of the free tensor algebra (or the corresponding subset
    of the shuffle tensor algebra) to the whole algebra. This implementation is
    far from optimal.

    :param x: tensor (either a free tensor or a shuffle tensor)
    :return: Log(x), with the same type as the input
    """
    ctx = x.context
    fn = _remove_constant
    out = fn(x)
    sign = False
    for i in range(2, ctx.depth + 1):
        sign = not sign
        fn = _convolve(_remove_constant, fn)
        tmp = fn(x) / i
        if sign:
            out -= tmp
        else:
            out += tmp
    return out
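# A minimal, hedged usage sketch. It assumes that rp.get_context, rp.DPReal and
# FreeTensor.exp behave as in the RoughPy documentation; adjust the names to
# the installed version if they differ.
if __name__ == "__main__":
    ctx = rp.get_context(width=2, depth=3, coeffs=rp.DPReal)

    # A tensor with zero constant term whose degree-one part is 1*(1) + 2*(2);
    # its exponential is a group-like element of the truncated tensor algebra.
    x = rp.FreeTensor([0.0, 1.0, 2.0], ctx=ctx)
    g = x.exp()

    # On group-like elements Log agrees with the usual tensor logarithm, so
    # this should reproduce x (up to truncation at the context depth).
    print(Log(g))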