-
Notifications
You must be signed in to change notification settings - Fork 1
/
annealing.py
83 lines (68 loc) · 2.94 KB
/
annealing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import math
class Annealer:
"""
This class is used to anneal the KL divergence loss over the course of training VAEs.
After each call, the step() function should be called to update the current epoch.
"""
def __init__(self, total_steps, shape='linear', baseline=0.0, cyclical=False, disable=False):
"""
Parameters:
total_steps (int): Number of epochs to reach full KL divergence weight.
shape (str): Shape of the annealing function. Can be 'linear', 'cosine', or 'logistic'.
baseline (float): Starting value for the annealing function [0-1]. Default is 0.0.
cyclical (bool): Whether to repeat the annealing cycle after total_steps is reached.
disable (bool): If true, the __call__ method returns unchanged input (no annealing).
"""
self.current_step = 0
if shape not in ['linear', 'cosine', 'logistic']:
raise ValueError("Shape must be one of 'linear', 'cosine', or 'logistic.")
self.shape = shape
if not 0 <= float(baseline) <= 1:
raise ValueError("Baseline must be a float between 0 and 1.")
self.baseline = baseline
if type(total_steps) is not int or total_steps < 1:
raise ValueError("Argument total_steps must be an integer greater than 0")
self.total_steps = total_steps
if type(cyclical) is not bool:
raise ValueError("Argument cyclical must be a boolean.")
self.cyclical = cyclical
if type(disable) is not bool:
raise ValueError("Argument disable must be a boolean.")
self.disable = disable
def __call__(self, kld):
"""
Args:
kld (torch.tensor): KL divergence loss
Returns:
out (torch.tensor): KL divergence loss multiplied by the slope of the annealing function.
"""
if self.disable:
return kld
out = kld * self._slope()
return out
def step(self):
if self.current_step < self.total_steps:
self.current_step += 1
if self.cyclical and self.current_step >= self.total_steps:
self.current_step = 0
return
def set_cyclical(self, value):
if not isinstance(value, bool):
raise ValueError("Argument to cyclical method must be a boolean.")
self.cyclical = value
return
def _slope(self):
if self.shape == 'linear':
y = (self.current_step / self.total_steps)
elif self.shape == 'cosine':
y = (math.cos(math.pi * (self.current_step / self.total_steps - 1)) + 1) / 2
elif self.shape == 'logistic':
exponent = ((self.total_steps / 2) - self.current_step)
y = 1 / (1 + math.exp(exponent))
else:
y = 1.0
y = self._add_baseline(y)
return y
def _add_baseline(self, y):
y_out = y * (1 - self.baseline) + self.baseline
return y_out