-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcommon.py
133 lines (103 loc) · 4.04 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import collections
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
# Named tuple grouping one batch of arrays. The field names suggest RL
# transitions (observation, action, reward, mask, next observation); confirm
# the exact semantics of ``masks`` with the data pipeline.
Batch = collections.namedtuple(
    'Batch', 'observations actions rewards masks next_observations')
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    """Return an orthogonal kernel initializer.

    Args:
        scale: multiplicative gain applied to the orthogonal matrix. Defaults
            to sqrt(2). Because the parameter is typed ``Optional``, passing
            ``None`` also selects that default gain rather than forwarding
            ``None`` to the initializer.

    Returns:
        The initializer produced by ``nn.initializers.orthogonal``.
    """
    if scale is None:
        scale = jnp.sqrt(2)
    return nn.initializers.orthogonal(scale)
# Shared type aliases used in annotations throughout this module.
PRNGKey = Any
Params = flax.core.FrozenDict[str, Any]
Shape = Sequence[int]
# Left as ``Any``: no concrete dtype type is enforced by this module.
Dtype = Any
# Mapping of metric name to scalar value.
InfoDict = Dict[str, float]
def _l2_normalize(x, eps=1e-4):
return x * jax.lax.rsqrt((x ** 2).sum() + eps)
def _l2_norm(x):
return jnp.sqrt((x ** 2).sum())
def _power_iteration(A, u, n_steps=10):
    """Refine an estimate of the first right-singular vector of the map ``A``.

    Each step applies ``A`` and then its transpose (obtained via ``jax.vjp``),
    i.e. ``u <- normalize(A^T A u)``. This is power iteration on ``A^T A``.

    Args:
        A: linear function of ``u``, such as the JVP function returned by
            ``jax.linearize`` (as ``estimate_spectral_norm`` uses it).
        u: initial vector, with the shape ``A`` expects as input.
        n_steps: number of iterations, executed with ``jax.lax.scan``.

    Returns:
        The refined vector. After at least one step it is rescaled by
        ``_l2_normalize``; with ``n_steps=0`` the input is returned as-is.
    """
    def step(u, _):
        v, A_transpose = jax.vjp(A, u)
        # The VJP returns one cotangent per primal input; A has a single input.
        u, = A_transpose(v)
        return _l2_normalize(u), None

    # ``lax`` is not imported in this module; reach it through ``jax``.
    u, _ = jax.lax.scan(step, u, xs=None, length=n_steps)
    return u
def estimate_spectral_norm(f, x, seed=0, n_steps=10):
    """Estimate the spectral norm of ``f``'s linearization at ``x``.

    The Jacobian-vector product of ``f`` at ``x`` is obtained with
    ``jax.linearize``. Power iteration from a random start (seeded by ``seed``)
    then approximates its top right-singular vector. The returned value is the
    norm of the linear map applied to that vector.
    """
    key = jax.random.PRNGKey(seed)
    start = jax.random.normal(key, x.shape)
    # jax.linearize returns (f(x), jvp_fn); only the linear map is needed here.
    linear_map = jax.linearize(f, x)[1]
    top_vector = _power_iteration(linear_map, start, n_steps)
    return _l2_norm(linear_map(top_vector))
class MLP(nn.Module):
    """Stack of ``nn.Dense`` layers with optional layer norm and dropout.

    Attributes:
        hidden_dims: output width of each Dense layer, in order; the last
            entry is the width of the network's output.
        activations: nonlinearity applied after every hidden layer (and after
            the final layer when ``activate_final`` is set).
        activate_final: if true, the final Dense layer also receives the
            layer-norm / activation / dropout treatment.
        layer_norm: if true, apply ``nn.LayerNorm`` before each activation.
        dropout_rate: if set and positive, apply dropout after each activation.
    """
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    layer_norm: bool = False
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        """Run the network on ``x``.

        Args:
            x: input array.
            training: when true, dropout is active (``deterministic=False``).
                Presumably the caller must then provide a dropout PRNG stream
                via ``apply(..., rngs=...)``; confirm against call sites.

        Returns:
            The output of the last Dense layer, post-processed only if
            ``activate_final`` is set.
        """
        num_layers = len(self.hidden_dims)
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            # Hidden layers always get norm/activation/dropout; the output
            # layer only when ``activate_final`` is requested.
            if i + 1 < num_layers or self.activate_final:
                if self.layer_norm:
                    x = nn.LayerNorm()(x)
                x = self.activations(x)
                if self.dropout_rate is not None and self.dropout_rate > 0:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x
@flax.struct.dataclass
class Model:
    """Immutable bundle of a flax module, its parameters and its optimizer.

    As a ``flax.struct.dataclass`` the instance is a pytree. ``step``,
    ``params`` and ``opt_state`` are pytree nodes; ``apply_fn`` and ``tx`` are
    static fields (``pytree_node=False``). Methods that "update" the model
    return a new instance built with ``replace``.
    """
    # Update counter: ``create`` starts it at 1, ``apply_gradient`` adds 1.
    step: int
    # Module definition whose ``apply`` evaluates ``params``.
    apply_fn: nn.Module = flax.struct.field(pytree_node=False)
    # The 'params' collection produced by ``model_def.init``.
    params: Params
    # Optax transformation; must be set for ``apply_gradient`` to work.
    tx: Optional[optax.GradientTransformation] = flax.struct.field(
        pytree_node=False)
    # Result of ``tx.init(params)``, or None when ``tx`` is None.
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        """Initialise ``model_def`` and wrap it, plus an optional optimizer.

        Args:
            model_def: flax module to initialise.
            inputs: positional arguments forwarded to ``model_def.init``.
                Presumably the PRNG key first, then example inputs; confirm
                against callers.
            tx: optional optax transformation. When given, its state is
                initialised from the freshly created parameters.

        Returns:
            A new ``Model`` with ``step=1``.
        """
        variables = model_def.init(*inputs)
        # Index rather than ``_, params = variables.pop('params')``. That
        # unpacking relies on FrozenDict.pop returning a (rest, value) pair,
        # but depending on the flax version ``init`` may return a plain dict,
        # whose ``pop`` yields only the value. Indexing works for both.
        params = variables['params']
        opt_state = tx.init(params) if tx is not None else None
        return cls(step=1,
                   apply_fn=model_def,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        """Evaluate the module with this model's ``params``."""
        return self.apply_fn.apply({'params': self.params}, *args, **kwargs)

    def apply(self, *args, **kwargs):
        """Forward directly to the module's ``apply``.

        The caller supplies the variables explicitly, unlike ``__call__``.
        """
        return self.apply_fn.apply(*args, **kwargs)

    def apply_gradient(self, loss_fn) -> Tuple['Model', Any]:
        """Take one optimizer step on ``params``.

        Args:
            loss_fn: function of the parameters returning ``(loss, info)``. It
                is differentiated with ``jax.grad(..., has_aux=True)``, so only
                ``loss`` is differentiated and ``info`` is passed through.

        Returns:
            ``(new_model, info)``, where ``new_model`` carries the updated
            ``params`` and ``opt_state`` and ``step`` incremented by one.
        """
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, info = grad_fn(self.params)
        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)
        return self.replace(step=self.step + 1,
                            params=new_params,
                            opt_state=new_opt_state), info

    def save(self, save_path: str):
        """Write ``params`` (only) to ``save_path``.

        Uses ``flax.serialization.to_bytes``. Parent directories are created
        if needed.
        """
        directory = os.path.dirname(save_path)
        # A bare filename has an empty dirname, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(flax.serialization.to_bytes(self.params))

    def load(self, load_path: str) -> 'Model':
        """Return a copy of this model with ``params`` read from ``load_path``.

        The current ``params`` serve as the structure template for
        ``flax.serialization.from_bytes``. ``step`` and ``opt_state`` are kept
        unchanged.
        """
        with open(load_path, 'rb') as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)