forked from louaaron/Score-Entropy-Discrete-Diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
noise_lib.py
73 lines (57 loc) · 1.99 KB
/
noise_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import abc
import torch
import torch.nn as nn
import numpy as np
def get_noise(config):
    """
    Construct the noise schedule selected by ``config.noise.type``.

    Supported types:
        "geometric" -> GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
        "loglinear" -> LogLinearNoise() with its default arguments

    Raises:
        ValueError: if ``config.noise.type`` is not one of the supported types.
    """
    noise_cfg = config.noise
    kind = noise_cfg.type

    if kind == "geometric":
        return GeometricNoise(noise_cfg.sigma_min, noise_cfg.sigma_max)
    if kind == "loglinear":
        return LogLinearNoise()
    raise ValueError(f"{kind} is not a valid noise")
class Noise(abc.ABC, nn.Module):
    """
    Abstract base class for a noise schedule over continuous time t.

    Subclasses implement ``rate_noise`` (the instantaneous rate g(t)) and
    ``total_noise`` (the accumulated noise sigma(t)). Calling the module
    returns both at once via ``forward``.
    """

    # Time is assumed to run from 0 to 1.

    def forward(self, t):
        """
        Return the total noise and the rate of noise at time ``t``.

        Args:
            t: time value(s); assumed to lie in [0, 1].

        Returns:
            Tuple ``(total_noise(t), rate_noise(t))``.
        """
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """
        Rate of change of the noise, i.e. g(t) = d sigma(t) / dt.
        """
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        r"""
        Total noise at time t, i.e. sigma(t) = sigma(0) + \int_0^t g(s) ds.
        """
        pass
class GeometricNoise(Noise, nn.Module):
    """
    Geometric noise schedule: sigma(t) = sigma_min^(1 - t) * sigma_max^t.

    Interpolates linearly in log space from ``sigma_min`` at t = 0 to
    ``sigma_max`` at t = 1.

    Args:
        sigma_min: noise level at t = 0; must be positive for the log in
            ``rate_noise`` to be finite.
        sigma_max: noise level at t = 1; must be positive.
        learnable: if True, the two endpoints are registered as an
            ``nn.Parameter`` so an optimizer can update them.
            NOTE(review): nothing keeps learned endpoints positive; a
            non-positive value makes ``rate_noise`` NaN.
    """

    def __init__(self, sigma_min=1e-3, sigma_max=1, learnable=False):
        super().__init__()
        # Multiplying by 1.0 promotes all-integer endpoints to floating point.
        sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
        if learnable:
            self.sigmas = nn.Parameter(sigmas)
        else:
            # A non-persistent buffer follows .to()/.cuda()/.double() with the
            # module (a plain attribute does not) while keeping "sigmas" out of
            # state_dict, so existing checkpoints still load.
            self.register_buffer("sigmas", sigmas, persistent=False)
        # NOTE(review): not used by this class; presumably lets callers query the
        # module's device/dtype through its parameters — confirm before removing.
        self.empty = nn.Parameter(torch.tensor(0.0))

    def rate_noise(self, t):
        """Return g(t) = sigma(t) * (log sigma_max - log sigma_min)."""
        return self.total_noise(t) * (self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        """Return sigma(t) = sigma_min^(1 - t) * sigma_max^t."""
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
class LogLinearNoise(Noise, nn.Module):
    """
    Log-linear noise schedule. Used for absorbing.

    The total noise is sigma(t) = -log(1 - (1 - eps) * t), chosen so that
    1 - exp(-sigma(t)) = (1 - eps) * t. That quantity therefore grows linearly
    from 0 at t = 0 to 1 - eps (close to 1) at t = 1.

    Args:
        eps: small offset keeping sigma(1) and the rate at t = 1 finite;
            presumably in (0, 1). With eps = 0 both diverge at t = 1.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        # NOTE(review): not used by this class; presumably lets callers query the
        # module's device/dtype through its parameters — confirm before removing.
        self.empty = nn.Parameter(torch.tensor(0.0))

    def rate_noise(self, t):
        """Return g(t) = d sigma / dt = (1 - eps) / (1 - (1 - eps) * t)."""
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        """
        Return sigma(t) = -log(1 - (1 - eps) * t).

        ``t`` may be a tensor or anything ``torch.as_tensor`` accepts (Python
        scalars, numpy arrays), matching ``rate_noise``, which also works on
        plain numbers. Tensor inputs pass through unchanged (same dtype/device).
        """
        # log1p(-x) is accurate when x = (1 - eps) * t is small.
        return -torch.log1p(-(1 - self.eps) * torch.as_tensor(t))