- Moved image-related layers to `serket.image`
- `ScanRNN` changes: `cell.init_state` is deprecated; use `sk.tree_state(cell, ...)` instead.
- Naming changes:
  - `***_init_func` -> `***_init` (shorter and more concise)
  - `gamma_init_func` -> `weight_init`
  - `beta_init_func` -> `bias_init`
  - `act_func` -> `act`
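If many call sites use the old keyword names, the renames can be applied mechanically. The sketch below is a hypothetical helper (`migrate_kwargs` and `EXPLICIT_RENAMES` are illustrative names, not part of serket) that rewrites old keyword arguments to the new names before they reach a layer constructor. It covers only the renames listed in the naming-changes item.

```python
# Hypothetical migration helper, not part of serket: maps old keyword
# argument names to the names introduced in this release.
EXPLICIT_RENAMES = {
    "gamma_init_func": "weight_init",
    "beta_init_func": "bias_init",
    "act_func": "act",
}


def migrate_kwargs(kwargs: dict) -> dict:
    """Return a copy of `kwargs` with old argument names replaced by new ones."""
    migrated = {}
    for name, value in kwargs.items():
        if name in EXPLICIT_RENAMES:
            # Explicit renames take priority over the generic suffix rule.
            new_name = EXPLICIT_RENAMES[name]
        elif name.endswith("_init_func"):
            # Generic rule: `***_init_func` -> `***_init`.
            new_name = name[: -len("_func")]
        else:
            new_name = name
        if new_name in migrated:
            raise ValueError(f"duplicate argument after renaming: {new_name!r}")
        migrated[new_name] = value
    return migrated


print(migrate_kwargs({"weight_init_func": "he_normal", "act_func": "relu", "in_features": 3}))
# {'weight_init': 'he_normal', 'act': 'relu', 'in_features': 3}
```

Explicit renames are checked first so that, for example, `gamma_init_func` becomes `weight_init` rather than `gamma_init`.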
- `MLP` produces smaller jaxprs and is faster to compile, which matters in my use case: higher-order differentiation through a `PINN`.
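A minimal sketch of one common way to get a smaller jaxpr, assuming the hidden layers share the same shape: stack their weights and loop with `jax.lax.scan`, so the layer body is traced once instead of once per layer. The functions `mlp_unrolled` and `mlp_scanned` are hypothetical, and this is not necessarily how serket's `MLP` is implemented.

```python
import jax
import jax.numpy as jnp


def mlp_unrolled(weights, x):
    # Python loop: each layer is traced separately, so the jaxpr grows with depth.
    for w in weights:
        x = jnp.tanh(x @ w)
    return x


def mlp_scanned(stacked_weights, x):
    # lax.scan traces the layer body once and iterates over the leading axis.
    def layer(carry, w):
        return jnp.tanh(carry @ w), None

    x, _ = jax.lax.scan(layer, x, stacked_weights)
    return x


num_layers, width = 10, 8
stacked = jax.random.normal(jax.random.PRNGKey(0), (num_layers, width, width)) / width**0.5
weights = list(stacked)  # the same weights as a Python list of (width, width) arrays
x = jnp.ones((width,))

unrolled_eqns = len(jax.make_jaxpr(mlp_unrolled)(weights, x).jaxpr.eqns)
scanned_eqns = len(jax.make_jaxpr(mlp_scanned)(stacked, x).jaxpr.eqns)
print(unrolled_eqns, scanned_eqns)
```

Counting `jaxpr.eqns` is a rough proxy for program size: the unrolled version adds equations for every layer, while the scanned version stays at a single `scan` equation regardless of depth.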
- `kernel_dilation` -> `dilation`
- `input_dilation` -> removed
- `p` -> `drop_rate` in all dropout layers
- `FlipLeftRight2D` -> `HorizontalFlip2D`
- `FlipUpDown2D` -> `VerticalFlip2D`
- `sk.nn.{Sequential,RandomChoice}` -> `sk.{Sequential,RandomChoice}`, as they apply to other modules and are not specific to `nn`
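- `tree_eval`: a dispatcher to define a layer's evaluation rule. For example, `Dropout` is changed to `Identity` when `tree_eval` is applied.

  ```python
  import serket as sk

  # Evaluation rule: every `Dropout` layer becomes an `Identity` layer.
  @sk.tree_eval.def_eval(sk.nn.Dropout)
  def dropout_evaluation(_) -> sk.nn.Identity:
      return sk.nn.Identity()
  ```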
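Conceptually, a dispatcher like `tree_eval` applies a per-type rule to every layer of a model. The stdlib sketch below illustrates that idea with `functools.singledispatch`; the `Linear`, `Dropout`, and `Identity` dataclasses and the `eval_rule` / `tree_eval_sketch` functions are hypothetical stand-ins, and the model is a flat list rather than an arbitrary pytree. It is not serket's implementation.

```python
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical stand-in layers, not serket classes.
@dataclass
class Linear:
    in_features: int
    out_features: int


@dataclass
class Dropout:
    drop_rate: float


@dataclass
class Identity:
    pass


@singledispatch
def eval_rule(layer):
    # Default rule: layers without a registered rule are kept unchanged.
    return layer


@eval_rule.register
def _(layer: Dropout) -> Identity:
    # Dropout is a no-op at evaluation time, so replace it with Identity.
    return Identity()


def tree_eval_sketch(layers: list) -> list:
    # Build a new model by applying the evaluation rule to each layer.
    return [eval_rule(layer) for layer in layers]


model = [Linear(3, 4), Dropout(0.5), Linear(4, 1)]
print(tree_eval_sketch(model))
# [Linear(in_features=3, out_features=4), Identity(), Linear(in_features=4, out_features=1)]
```

Because the sketch returns a new list, the training-time model is left untouched and can still be used after the evaluation copy is made.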
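- `tree_state`: a dispatcher to define state initialization for `BatchNorm` and RNN cells.

  ```python
  import jax
  import jax.numpy as jnp
  import serket as sk

  @sk.tree_state.def_state(sk.nn.SimpleRNNCell)
  def simple_rnn_init_state(cell: sk.nn.SimpleRNNCell, array: jax.Array | None) -> SimpleRNNState:
      del array  # unused
      # Zero hidden state; `SimpleRNNState` is the cell's state container (its import is not shown).
      return SimpleRNNState(jnp.zeros([cell.hidden_features]))
  ```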
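To show how an initial state produced by such a dispatcher is then consumed, here is a stdlib sketch under simplifying assumptions: a hypothetical scalar-input `SimpleRNNCell` with fixed weights, a hypothetical `init_state` dispatcher built on `functools.singledispatch`, and a hypothetical `scan_rnn` loop. None of these are serket's API; the loop only mirrors the role that `sk.tree_state(cell, ...)` plays in place of `cell.init_state`.

```python
import math
from dataclasses import dataclass
from functools import singledispatch


@dataclass
class SimpleRNNCell:
    # Hypothetical cell: h' = tanh(w_in * x + w_h * h), applied elementwise.
    hidden_features: int
    w_in: float = 0.5
    w_h: float = 0.5

    def __call__(self, x: float, h: list[float]) -> list[float]:
        return [math.tanh(self.w_in * x + self.w_h * hi) for hi in h]


@singledispatch
def init_state(layer):
    # Default: layers without a registered rule carry no state.
    return None


@init_state.register
def _(cell: SimpleRNNCell) -> list[float]:
    # The cell starts from a zero hidden state of size `hidden_features`.
    return [0.0] * cell.hidden_features


def scan_rnn(cell: SimpleRNNCell, xs: list[float]) -> list[float]:
    # Obtain the initial state from the dispatcher, then thread it through the sequence.
    h = init_state(cell)
    for x in xs:
        h = cell(x, h)
    return h


cell = SimpleRNNCell(hidden_features=3)
print(init_state(cell))  # [0.0, 0.0, 0.0]
print(scan_rnn(cell, [1.0, -1.0]))
```

Keeping state initialization outside the cell means the scanning loop needs no knowledge of how each cell type builds its state.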
- `MultiHeadAttention`
- `BatchNorm`
- `RandomHorizontalShear2D`
- `RandomPerspective2D`
- `RandomRotate2D`
- `RandomVerticalShear2D`
- `Rotate2D`
- `VerticalShear2D`
- `Pixelate2D`
- `Solarize2D`
- `Posterize2D`
- `RandomJigSaw2D`
- `FFTAvgBlur2D`
- `FFTGaussianBlur2D`
- `Bilinear` is deprecated; use `Multilinear((in1_features, in2_features), out_features)` instead.
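To see why a multilinear layer subsumes a bilinear one, here is a pure-Python sketch assuming the usual definition: output `k` is `bias[k]` plus the sum, over every combination of feature indices, of `weight[k][i1]...[iN]` times `x1[i1] * ... * xN[iN]`. With two inputs this is the bilinear form `x1ᵀ W_k x2 + b_k`. The function name `multilinear` and the nested-list weight layout are illustrative assumptions, not serket's internals.

```python
import itertools
from typing import Sequence


def multilinear(inputs: Sequence[Sequence[float]], weight, bias: Sequence[float]) -> list[float]:
    """Multilinear map over N input vectors (illustrative sketch).

    weight[k][i1][i2]...[iN] couples output k with feature i1 of input 1,
    feature i2 of input 2, and so on; bias[k] is added to output k.
    """
    dims = [len(x) for x in inputs]
    out = []
    for w_k, b_k in zip(weight, bias):
        total = b_k
        # Sum over every combination of feature indices (i1, ..., iN).
        for idx in itertools.product(*(range(d) for d in dims)):
            w = w_k
            coeff = 1.0
            for x, i in zip(inputs, idx):
                w = w[i]  # descend one level into the nested weight
                coeff *= x[i]  # accumulate x1[i1] * ... * xN[iN]
            total += w * coeff
        out.append(total)
    return out


# Two inputs (the old `Bilinear` case): in1_features=2, in2_features=1, out_features=1.
print(multilinear([[1.0, 2.0], [3.0]], weight=[[[1.0], [1.0]]], bias=[0.5]))  # [9.5]
```

With a single input the same function reduces to an ordinary linear layer, which is why one `Multilinear` with a tuple of input sizes can replace a dedicated `Bilinear` class.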
- `HistogramEqualization2D`
- Removed `.blocks`; it will be moved to the examples.