-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
loss_functions.py
653 lines (525 loc) · 22 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions to be used by LayerCollection."""
import abc
from typing import Tuple, Optional, Union, Sequence
import jax
import jax.numpy as jnp
from kfac_ferminet_alpha import distributions
from kfac_ferminet_alpha import layers_and_loss_tags as tags
from kfac_ferminet_alpha import utils
# A pair of arrays.
ArrayPair = Tuple[jnp.ndarray, jnp.ndarray]
# A scalar coefficient given either as a Python float or as an array.
FloatArray = Union[float, jnp.ndarray]
# Index of one entry inside a per-example slice. Its length depends on the
# rank of the relevant `*_factor_inner_shape`, so it is a variable-length
# tuple of ints rather than a 1-tuple.
Index = Tuple[int, ...]
class LossFunction(abc.ABC):
  """Common interface for the loss functions used by LayerCollection.

  In contrast to the usual neural-network losses, nothing here is summed or
  averaged over the batch, so the result of `evaluate()` is in general not a
  scalar. Any reduction is the responsibility of the caller.
  """

  def __init__(self, weight: FloatArray):
    self._weight = weight

  @property
  def weight(self) -> FloatArray:
    """The coefficient this loss (and its GGN) is scaled by."""
    return self._weight

  @property
  @abc.abstractmethod
  def targets(self) -> Optional[jnp.ndarray]:
    """The targets the model is predicting.

    Returns:
      Either None, or a tensor of a shape that self._evaluate() accepts.
    """

  @property
  @abc.abstractmethod
  def inputs(self) -> Sequence[jnp.ndarray]:
    """The arguments of the loss function, not counting the targets."""

  @abc.abstractmethod
  def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]):
    """Builds a loss like this one that uses `inputs` in place of `self.inputs`.

    `grad_of_evaluate` relies on this to differentiate the loss with respect
    to its inputs.
    """

  def evaluate(
      self,
      targets: Optional[jnp.ndarray] = None,
      coefficient_mode: str = "regular",
  ) -> jnp.ndarray:
    """Evaluates the loss on the given targets.

    Args:
      targets: The targets to evaluate on. When None, `self.targets` is used.
      coefficient_mode: How the weight enters the result: "regular" multiplies
        by the weight, "sqrt" by its square root and "off" by 1.0.

    Returns:
      The value of `self._evaluate(targets)` times the selected multiplier.

    Raises:
      ValueError: If no targets are available, or if `coefficient_mode` is not
        one of the recognized values.
    """
    if targets is None:
      # Fall back to the targets stored on the loss, if there are any.
      targets = self.targets
      if targets is None:
        raise ValueError("Cannot evaluate losses with unspecified targets.")

    if coefficient_mode == "off":
      scale = 1.0
    elif coefficient_mode == "sqrt":
      scale = jnp.sqrt(self.weight)
    elif coefficient_mode == "regular":
      scale = self.weight
    else:
      raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode}.")

    return self._evaluate(targets) * scale

  @abc.abstractmethod
  def _evaluate(self, targets: jnp.ndarray) -> jnp.ndarray:
    """Computes the loss for `targets` before any weighting is applied.

    For the probabilistic losses in this file this is the negative log
    probability of the targets.

    Args:
      targets: Tensor that the loss can be evaluated on.

    Returns:
      The unweighted loss; `evaluate()` multiplies it by the coefficient.
    """

  def grad_of_evaluate(
      self,
      targets: Optional[jnp.ndarray],
      coefficient_mode: str,
  ) -> Sequence[jnp.ndarray]:
    """Differentiates the summed loss with respect to the inputs.

    The targets used (either the argument or, when it is None, the ones stored
    on the loss) must not be `None`.

    Args:
      targets: The potential targets on which to evaluate the gradient.
      coefficient_mode: The coefficient mode to use for evaluation.

    Returns:
      The gradient of the summed loss evaluation with respect to the inputs.
    """

    def summed_loss(loss_inputs: Sequence[jnp.ndarray]) -> jnp.ndarray:
      rebuilt = self.copy_with_different_inputs(loss_inputs)
      # `jax.grad` needs a scalar, hence the sum over all entries.
      return jnp.sum(rebuilt.evaluate(targets, coefficient_mode))

    grad_fn = jax.grad(summed_loss)
    return grad_fn(self.inputs)

  def multiply_ggn(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiplies a vector by the GGN.

    The 'GGN' is the GGN matrix (whose definition is slightly flexible) of the
    loss function with respect to its inputs.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by the GGN, scaled by the weight. Will be of
      the same shape(s) as the 'inputs' property.
    """
    unweighted = self.multiply_ggn_unweighted(vector)
    return utils.scalar_mul(unweighted, self.weight)

  @abc.abstractmethod
  def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_ggn`, but without taking into account the weight."""

  def multiply_ggn_factor(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiplies a vector by a factor B of the GGN.

    The 'GGN' is the GGN matrix (whose definition is slightly flexible) of the
    loss function with respect to its inputs. It is typically block-diagonal
    across the cases in the batch, because the loss is typically summed across
    cases.

    B may be any matrix satisfying B * B^T = G, where G is the GGN, but it
    agrees with the one used by the other methods of this class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'ggn_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B, scaled by the square root of the
      weight. Will be of the same shape(s) as the 'inputs' property.
    """
    unweighted = self.multiply_ggn_factor_unweighted(vector)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_ggn_factor`, but without taking into account the weight."""

  def multiply_ggn_factor_transpose(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiplies a vector by the transpose of a factor B of the GGN.

    The 'GGN' is the GGN matrix (whose definition is slightly flexible) of the
    loss function with respect to its inputs. It is typically block-diagonal
    across the cases in the batch, because the loss is typically summed across
    cases.

    B may be any matrix satisfying B * B^T = G, where G is the GGN, but it
    agrees with the one used by the other methods of this class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by B^T, scaled by the square root of the
      weight. Will be of the shape given by the 'ggn_factor_inner_shape'
      property.
    """
    unweighted = self.multiply_ggn_factor_transpose_unweighted(vector)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_ggn_factor_transpose`, but without taking into account the weight."""

  def multiply_ggn_factor_replicated_one_hot(self, index: Index) -> jnp.ndarray:
    """Right-multiplies a replicated-one-hot vector by a factor B of the GGN.

    The 'GGN' is the GGN matrix (whose definition is slightly flexible) of the
    loss function with respect to its inputs. It is typically block-diagonal
    across the cases in the batch, because the loss is typically summed across
    cases.

    A 'replicated-one-hot' vector is a tensor which, for every slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry given by
    the index and 0 everywhere else.

    B may be any matrix satisfying B * B^T = G, where G is the GGN, but it
    agrees with the one used by the other methods of this class.

    Args:
      index: A tuple giving the index of the entry in each slice that is 1.0.
        Note that len(index) must be equal to the number of elements of the
        'ggn_factor_inner_shape' tensor minus one.

    Returns:
      The replicated-one-hot vector right-multiplied by B, scaled by the
      square root of the weight. Will be of the same shape(s) as the 'inputs'
      property.
    """
    unweighted = self.multiply_ggn_factor_replicated_one_hot_unweighted(index)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_ggn_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Same as `multiply_ggn_factor_replicated_one_hot`, ignoring the weight."""

  @property
  @abc.abstractmethod
  def ggn_factor_inner_shape(self) -> Sequence[int]:
    """The shape of the vectors that `multiply_ggn_factor` accepts."""
class NegativeLogProbLoss(LossFunction):
  """Abstract base class for losses defined as a negative log probability."""

  @property
  def inputs(self):
    """The loss inputs, which are the parameters of the distribution."""
    return self.params

  @property
  @abc.abstractmethod
  def params(self):
    """Parameters to the underlying distribution."""

  def multiply_fisher(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiplies a vector by the Fisher.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by the Fisher, scaled by the weight. Will be
      of the same shape(s) as the 'inputs' property.
    """
    unweighted = self.multiply_fisher_unweighted(vector)
    return utils.scalar_mul(unweighted, self.weight)

  @abc.abstractmethod
  def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Same as `multiply_fisher`, but without taking into account the weight."""

  def multiply_fisher_factor(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Right-multiplies a vector by a factor B of the Fisher.

    The 'Fisher' is the Fisher information matrix (i.e. the expected outer
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). It is
    typically block-diagonal across the cases in the batch, because the
    distribution is usually (but not always) conditionally iid across cases.

    B may be any matrix satisfying B * B^T = F, where F is the Fisher, but it
    agrees with the one used by the other methods of this class.

    Args:
      vector: The vector to multiply. Must be of the shape given by the
        'fisher_factor_inner_shape' property.

    Returns:
      The vector right-multiplied by B, scaled by the square root of the
      weight. Will be of the same shape(s) as the 'inputs' property.
    """
    unweighted = self.multiply_fisher_factor_unweighted(vector)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor`, but ignoring the weight."""

  def multiply_fisher_factor_transpose(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Right-multiplies a vector by the transpose of a factor B of the Fisher.

    The 'Fisher' is the Fisher information matrix (i.e. the expected outer
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). It is
    typically block-diagonal across the cases in the batch, because the
    distribution is usually (but not always) conditionally iid across cases.

    B may be any matrix satisfying B * B^T = F, where F is the Fisher, but it
    agrees with the one used by the other methods of this class.

    Args:
      vector: The vector to multiply. Must be the same shape(s) as the 'inputs'
        property.

    Returns:
      The vector right-multiplied by B^T, scaled by the square root of the
      weight. Will be of the shape given by the 'fisher_factor_inner_shape'
      property.
    """
    unweighted = self.multiply_fisher_factor_transpose_unweighted(vector)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor_transpose`, but ignoring the weight."""

  def multiply_fisher_factor_replicated_one_hot(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Right-multiplies a replicated-one-hot vector by a factor B of the Fisher.

    The 'Fisher' is the Fisher information matrix (i.e. the expected outer
    product of gradients) with respect to the parameters of the underlying
    probability distribution (whose log-prob defines the loss). It is
    typically block-diagonal across the cases in the batch, because the
    distribution is usually (but not always) conditionally iid across cases.

    A 'replicated-one-hot' vector is a tensor which, for every slice along the
    batch dimension (assumed to be dimension 0), is 1.0 in the entry given by
    the index and 0 everywhere else.

    B may be any matrix satisfying B * B^T = F, where F is the Fisher, but it
    agrees with the one used by the other methods of this class.

    Args:
      index: A tuple giving the index of the entry in each slice that is 1.0.
        Note that len(index) must be equal to the number of elements of the
        'fisher_factor_inner_shape' tensor minus one.

    Returns:
      The replicated-one-hot vector right-multiplied by B, scaled by the
      square root of the weight. Will be of the same shape(s) as the 'inputs'
      property.
    """
    unweighted = self.multiply_fisher_factor_replicated_one_hot_unweighted(
        index)
    return utils.scalar_mul(unweighted, jnp.sqrt(self.weight))

  @abc.abstractmethod
  def multiply_fisher_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Same as `multiply_fisher_factor_replicated_one_hot`, ignoring the weight."""

  @property
  @abc.abstractmethod
  def fisher_factor_inner_shape(self) -> Sequence[int]:
    """The shape of the vectors that `multiply_fisher_factor` accepts."""

  @abc.abstractmethod
  def sample(self, rng_key: jnp.ndarray) -> jnp.ndarray:
    """Sample 'targets' from the underlying distribution."""

  def grad_of_evaluate_on_sample(
      self,
      rng_key: jnp.ndarray,
      coefficient_mode: str,
  ) -> Sequence[jnp.ndarray]:
    """Evaluates the gradient of the loss on randomly sampled targets.

    Args:
      rng_key: Jax PRNG key for sampling.
      coefficient_mode: The coefficient mode to use for evaluation.

    Returns:
      The result of `grad_of_evaluate` for targets drawn with `self.sample`.
    """
    sampled_targets = self.sample(rng_key)
    return self.grad_of_evaluate(sampled_targets, coefficient_mode)
class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss, abc.ABC):
  """Base class for neg log prob losses whose inputs are 'natural' parameters.

  The GGN of the loss is taken to be the Fisher associated with the
  distribution, which for this class of loss functions also happens to equal
  the Hessian. See: https://arxiv.org/abs/1412.1193

  'Natural parameters' are defined for exponential-family models. See for
  example: https://en.wikipedia.org/wiki/Exponential_family

  Every GGN method below therefore simply forwards to its Fisher counterpart.
  """

  def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Forwards to `multiply_fisher_unweighted`."""
    fisher_product = self.multiply_fisher_unweighted(vector)
    return fisher_product

  def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Forwards to `multiply_fisher_factor_unweighted`."""
    factor_product = self.multiply_fisher_factor_unweighted(vector)
    return factor_product

  def multiply_ggn_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray
  ) -> jnp.ndarray:
    """Forwards to `multiply_fisher_factor_transpose_unweighted`."""
    transpose_product = self.multiply_fisher_factor_transpose_unweighted(vector)
    return transpose_product

  def multiply_ggn_factor_replicated_one_hot_unweighted(
      self,
      index: Index
  ) -> jnp.ndarray:
    """Forwards to `multiply_fisher_factor_replicated_one_hot_unweighted`."""
    one_hot_product = (
        self.multiply_fisher_factor_replicated_one_hot_unweighted(index))
    return one_hot_product

  @property
  def ggn_factor_inner_shape(self) -> Sequence[int]:
    """Identical to `fisher_factor_inner_shape`."""
    return self.fisher_factor_inner_shape
class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
  """Base class for neg log prob losses backed by a distribution object.

  Subclasses provide the distribution through `dist`; evaluation, sampling and
  the Fisher factor inner shape are all derived from it.
  """

  @property
  @abc.abstractmethod
  def dist(self):
    """The underlying distribution instance."""

  def _evaluate(self, targets: jnp.ndarray):
    """Returns the negated `log_prob` of `targets` under `self.dist`."""
    log_prob = self.dist.log_prob(targets)
    return -log_prob

  def sample(self, rng_key: jnp.ndarray):
    """Draws targets from `self.dist`, passing `rng_key` as the seed."""
    distribution = self.dist
    return distribution.sample(seed=rng_key)

  @property
  def fisher_factor_inner_shape(self) -> Sequence[int]:
    """The shape of the mean of `self.dist`."""
    dist_mean = self.dist.mean()
    return dist_mean.shape
class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
                                    NaturalParamsNegativeLogProbLoss):
  """Neg log prob loss for a normal distribution parameterized by a mean vector.

  The covariance is `variance * I`. With the default `variance=0.5` this is
  the identity divided by 2.

  The Fisher of such a normal distribution with respect to the mean parameter
  is given by:

     F = (1 / variance) * I

  See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
  """

  def __init__(
      self,
      mean: jnp.ndarray,
      targets: Optional[jnp.ndarray] = None,
      variance: float = 0.5,
      weight: float = 1.0,
  ):
    """Initializes the loss.

    Args:
      mean: The mean of the distribution, which is the single input of the
        loss. `multiply_fisher_factor_replicated_one_hot_unweighted` reads
        `mean.shape[0]` and `mean.shape[1]`, so it needs at least two
        dimensions there.
      targets: Optional targets for the loss.
      variance: The variance shared by all entries. Must be a positive Python
        float or int.
      weight: The coefficient that scales the loss.

    Raises:
      ValueError: If `variance` is not a Python float/int or is not positive.
    """
    # Validate before storing anything. `bool` is a subclass of `int`, so it
    # has to be rejected explicitly.
    if isinstance(variance, bool) or not isinstance(variance, (int, float)):
      raise ValueError(
          "The `variance` argument should be a python float or int.")
    # `not variance > 0` (rather than `variance <= 0`) also rejects NaN.
    if not variance > 0:
      raise ValueError(
          f"The `variance` argument must be positive, but was {variance}.")
    super().__init__(weight=weight)
    self._mean = mean
    self._targets = targets
    self._variance = variance

  @property
  def targets(self) -> Optional[jnp.ndarray]:
    """The targets given at construction time, possibly None."""
    return self._targets

  @property
  def dist(self):
    """A diagonal normal with mean `mean` and scale `sqrt(variance)`."""
    scale_diag = jnp.full_like(self._mean, jnp.sqrt(self._variance))
    return distributions.MultivariateNormalDiag(self._mean, scale_diag)

  @property
  def params(self):
    """The distribution parameters as a 1-tuple containing the mean."""
    return (self._mean,)

  def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]):
    """Returns a copy of this loss whose mean is the single entry of `inputs`."""
    [mean] = inputs
    return NormalMeanNegativeLogProbLoss(
        mean=mean,
        targets=self.targets,
        variance=self._variance,
        weight=self.weight,
    )

  def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray:
    """Multiplies by F = (1 / variance) * I."""
    return vector / self._variance

  def multiply_fisher_factor_unweighted(
      self,
      vector: jnp.ndarray,
  ) -> jnp.ndarray:
    """Multiplies by B = (1 / sqrt(variance)) * I, which satisfies B B^T = F."""
    return vector / jnp.sqrt(self._variance)

  def multiply_fisher_factor_transpose_unweighted(
      self,
      vector: jnp.ndarray,
  ) -> jnp.ndarray:
    """Multiplies by B^T, which equals B because B is symmetric."""
    return self.multiply_fisher_factor_unweighted(vector)  # it's symmetric

  def multiply_fisher_factor_replicated_one_hot_unweighted(
      self,
      index: Index,
  ) -> jnp.ndarray:
    """Multiplies B by the replicated-one-hot vector selected by `index`.

    Args:
      index: A 1-tuple with the position, along dimension 1 of the mean, of
        the entry that is 1.0 in every batch slice.

    Returns:
      An array of shape `(mean.shape[0], mean.shape[1])` that holds
      `1 / sqrt(variance)` in column `index[0]` and zeros elsewhere.
    """
    # Kept as an assertion so that the exception type seen by existing callers
    # does not change.
    assert len(index) == 1, f"Length of index was {len(index)}."
    (position,) = index
    batch_size, dim_size = self._mean.shape[0], self._mean.shape[1]
    # One column of shape (batch_size, 1), later embedded at `position`.
    column = jnp.ones([batch_size])[..., None] / jnp.sqrt(self._variance)
    return insert_slice_in_zeros(column, 1, dim_size, position)
def insert_slice_in_zeros(
    slice_to_insert: jnp.ndarray,
    dim: int,
    dim_size: int,
    position: int,
) -> jnp.ndarray:
  """Inserts slice into a larger tensor of zeros.

  Forms a new tensor which is the same shape as slice_to_insert, except that
  the dimension given by 'dim' is expanded to the size given by 'dim_size'.
  'position' determines the position (index) at which to insert the slice within
  that dimension.

  Assumes slice_to_insert.shape[dim] = 1.

  Args:
    slice_to_insert: The slice to insert.
    dim: The dimension which to expand with zeros.
    dim_size: The new size of the 'dim' dimension.
    position: The position of 'slice_to_insert' in the new tensor. Must satisfy
      0 <= position < dim_size.

  Returns:
    The new tensor.

  Raises:
    ValueError: If the slice's shape at the given dim is not 1, or if
      'position' does not lie inside the expanded dimension.
  """
  slice_shape = slice_to_insert.shape
  if slice_shape[dim] != 1:
    raise ValueError(f"Expected slice_to_insert.shape to have {dim} dim of 1,"
                     f" but was {slice_to_insert.shape[dim]}.")
  # An out-of-range position would produce a negative pad width below, so it
  # is rejected here with an explicit message.
  if not 0 <= position < dim_size:
    raise ValueError(f"Expected position to be in [0, {dim_size}),"
                     f" but was {position}.")
  # Pad only along `dim`: `position` zeros in front of the slice and the
  # remaining `dim_size - position - 1` zeros behind it.
  rank = len(slice_shape)
  pad_before = [0] * rank
  pad_after = [0] * rank
  pad_before[dim] = position
  pad_after[dim] = dim_size - position - 1
  return jnp.pad(slice_to_insert, list(zip(pad_before, pad_after)))
# _______ _____ _ _ _ _
# |__ __| | __ \ (_) | | | | (_)
# | | __ _ __ _ | |__) |___ __ _ _ ___| |_ _ __ __ _| |_ _ ___ _ __
# | |/ _` |/ _` | | _ // _ \/ _` | / __| __| '__/ _` | __| |/ _ \| '_ \
# | | (_| | (_| | | | \ \ __/ (_| | \__ \ |_| | | (_| | |_| | (_) | | | |
# |_|\__,_|\__, | |_| \_\___|\__, |_|___/\__|_| \__,_|\__|_|\___/|_| |_|
# __/ | __/ |
# |___/ |___/
# Loss tag for `NormalMeanNegativeLogProbLoss`, declared with a single input
# (`num_inputs=1`). `register_normal_predictive_distribution` binds this tag
# to the mean and targets.
NormalMeanNegativeLogProbLoss_tag = tags.LossTag(
    NormalMeanNegativeLogProbLoss, num_inputs=1)
def register_normal_predictive_distribution(
    mean: jnp.ndarray,
    targets: Optional[jnp.ndarray] = None,
    variance: float = 0.5,
    weight: float = 1.0,
):
  """Registers a normal predictive distribution.

  This corresponds to a squared error loss of the form
     weight/(2*var) * ||target - mean||^2

  Args:
    mean: A tensor defining the mean vector of the distribution. The first
      dimension must be the batch size.
    targets: (OPTIONAL) The targets for the loss function. Only required if one
      wants to use the "empirical Fisher" instead of the true Fisher (which is
      controlled by the 'estimation_mode' to the optimizer). When omitted, a
      tensor of zeros shaped like `mean` is bound in its place.
      (Default: None)
    variance: float. The variance of the distribution. Note that the default
      value of 0.5 corresponds to a standard squared error loss weight *
      ||target - prediction||^2. If you want your squared error loss to be of
      the form 0.5*coeff*||target - prediction||^2 you should use
      variance=1.0.
      (Default: 0.5)
    weight: A scalar coefficient to multiply the log prob loss associated with
      this distribution. The Fisher will be multiplied by the corresponding
      factor. In general this is NOT equivalent to changing the temperature of
      the distribution, but in the case of normal distributions it may be.
      (Default: 1.0)

  Returns:
    The mean and targets as dependable on the tag.
  """
  # The tag is always bound with concrete targets, so substitute zeros when
  # the caller did not provide any.
  bound_targets = jnp.zeros_like(mean) if targets is None else targets
  return NormalMeanNegativeLogProbLoss_tag.bind(
      mean,
      bound_targets,
      variance=variance,
      weight=weight,
      return_loss=False,
  )
def register_squared_error_loss(
    prediction: jnp.ndarray,
    targets: Optional[jnp.ndarray] = None,
    weight: float = 1.0,
):
  """Registers a squared error loss function.

  This assumes the squared error loss of the form ||target - prediction||^2,
  averaged across the mini-batch. If your loss uses a coefficient of 0.5
  you need to set the "weight" argument to reflect this.

  Args:
    prediction: The prediction made by the network (i.e. its output). The first
      dimension must be the batch size.
    targets: (OPTIONAL) The targets for the loss function. Only required if one
      wants to use the "empirical Fisher" instead of the true Fisher (which is
      controlled by the 'estimation_mode' to the optimizer).
      (Default: None)
    weight: A float coefficient to multiply the loss function by.
      (Default: 1.0)

  Returns:
    The mean and targets as dependable on the tag.
  """
  # A squared error loss is registered as a normal predictive distribution
  # whose mean is the prediction and whose variance is fixed at 0.5.
  squared_error_variance = 0.5
  return register_normal_predictive_distribution(
      prediction,
      targets=targets,
      variance=squared_error_variance,
      weight=weight,
  )