Linen: a comfortable refinement of FLAX
Goals
[Figure: example model with an Encoder and a Decoder, each made of blocks (block1, block2); each block is a stack of conv + BN layers.]
Core view of dataflow
[Figure, shown first simplified and then in full: data flows through the model together with its variables, e.g. parameters, batch norm stats.]
from typing import Callable, Sequence

import jax.numpy as jnp
from jax import random
from flax.core import Scope, init, apply, nn   # nn here is the functional-core layer library

Array = jnp.ndarray   # type alias used in the signature below

def simple_mlp(scope: Scope, x: Array,
               sizes: Sequence[int],
               act_fn: Callable[[Array], Array] = nn.relu):
  # a mutable 'counter' variable, updated with fresh randomness on each call
  i = scope.variable('counter', 'i', jnp.zeros, ())
  i.value += random.uniform(scope.make_rng('counter'), ())
  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(nn.dense)(x, size)
    x = act_fn(x)
  # output layer
  return scope.child(nn.dense, 'out')(x, sizes[-1])

# rngs, x, x2 are free variables here; see the driver sketch below
y, variables = init(simple_mlp)(rngs, x, sizes=(2, 8, 4))
apply(simple_mlp, mutable=('counter',))(variables, x2, rngs=rngs)
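A hypothetical driver for the two calls above; the shapes, and especially the RNG stream names, are illustrative assumptions (the parameter stream may be named 'param' or 'params' depending on the Flax version).

# Hypothetical driver for the init/apply calls above; shapes and RNG
# stream names are illustrative assumptions.
x = jnp.ones((1, 2))      # a batch with 2 input features
x2 = jnp.ones((1, 2))
rngs = {'params': random.PRNGKey(0),    # parameter initialization stream
        'counter': random.PRNGKey(1)}   # consumed by scope.make_rng('counter')

y, variables = init(simple_mlp)(rngs, x, sizes=(2, 8, 4))
# apply with mutable returns the output plus the mutated collections
y2, mutated = apply(simple_mlp, mutable=('counter',))(variables, x2, rngs=rngs)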
Track separate 'Kinds' of variables:
• differentiable parameters
• batch norm stats
• autoregressive cache
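As a rough sketch of what this looks like concretely (the collection names 'param' and 'batch_stats' follow these slides; 'cache' and the layer names are assumptions), the variables returned by init form a nested dict grouped by kind:

# Illustrative only: one sub-dict per kind, keyed by module path.
variables = {
    'param':       {'dense_0': {'kernel': ..., 'bias': ...}},
    'batch_stats': {'bn_0':    {'mean': ..., 'var': ...}},
    'cache':       {'attn_0':  {'cached_key': ..., 'cached_value': ...}},
}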
We want to apply remat, vmap, scan, etc. to a module.
[Figure: the transformation boundary drawn around a module.]
When needed for the transform, we specify transformation rules (e.g. in_axes, out_axes) per kind.
Lifting for JAX Transformations
from flax.core import lift

# add a group dimension to dense
group_dense = lift.vmap(
    nn.dense,
    in_axes=0,
    variable_in_axes={'param': 0},
    variable_out_axes={'param': 0},
    split_rngs={'param': True})
Lifting for JAX Transformations
from flax.core import lift

# a resnet of n blocks with O(sqrt(n)) memory and O(1) compile time
def body_fn(scope, x):
  return residual_block(scope, x)

def resnet(scope, x):
  # residual net with 10 x 10 = 100 blocks
  return lift.remat_scan(
      body_fn, scope, x, lengths=(10, 10),
      variable_modes={'param': 'scan', 'batch_stats': 'scan'},
      split_rngs={'param': True})
Lifting for JAX Transformations
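For intuition about lengths=(10, 10): the scanned computation behaves like 100 sequential applications of the block, each with its own parameters. A rough illustrative equivalent follows (remat_scan actually stacks parameters along the scan axes rather than naming each block, and rematerializes activations):

# Rough, unrolled equivalent of the remat_scan call above.
def unrolled(scope, x):
  for i in range(10):        # outer scan dimension
    for j in range(10):      # inner scan dimension (rematerialized)
      x = scope.child(residual_block, f'block_{i}_{j}')(x)
  return x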
Proposed User-facing Module API
class SimpleDense(nn.Module):
  in_features: int
  out_features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  use_bias: bool = True

  def setup(self):
    self.kernel = self.param('kernel', self.kernel_init,
                             (self.in_features, self.out_features))
    if self.use_bias:
      self.bias = self.param('bias', self.bias_init,
                             (self.out_features,))

  def __call__(self, inputs):
    y = lax.dot_general(inputs, self.kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())))
    if self.use_bias:
      y = y + self.bias
    return y
Module - define variables in setup, use in call
A Module always has a single naming scope for variables and submodules.
No sharing by name is permitted, only sharing by reference (see the sketch below).
We use setup() instead of __init__() because dataclasses take over __init__().
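A small sketch of sharing by reference (the module, the Dense usage, and the sizes are hypothetical): one submodule instance defined in setup is called twice, so both call sites share the same parameters.

class TiedMLP(nn.Module):    # hypothetical example
  features: int

  def setup(self):
    # a single instance, a single name in this Module's scope
    self.dense = Dense(features=self.features)

  def __call__(self, x):
    # assumes x already has `features` features, so both applications
    # reuse the same (features, features) kernel
    x = self.dense(x)        # first use
    x = self.dense(x)        # second use: same parameters, shared by reference
    return x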
class SimpleDense(nn.Module):
  # no `in_features` -- lazy initialization with shape inference
  out_features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  use_bias: bool = True

  @compact
  def __call__(self, inputs):
    kernel = self.param('kernel', self.kernel_init,
                        (inputs.shape[-1], self.out_features))
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())))
    if self.use_bias:
      bias = self.param('bias',
                        self.bias_init, (self.out_features,))
      y = y + bias
    return y
Module - colocating variable definition and use via @compact
For more concise module definitions, use the @compact method decorator.
Can only be used on a single method (avoids complicated naming/reuse rules).
Lets you co-locate variable definition and use, which can lead to clearer, shorter code.
Avoids the need to duplicate control flow (e.g. conditionals, loops) between setup and call.
Builds on JAX's shape inference -- no need for in_features.
Hyperparameters are declared as dataclass attributes.
Parameters are lazily initialized: if the data is not already in scope, the initializer is used; otherwise the existing data is returned (see the init/apply sketch below).
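For completeness, a sketch of constructing, initializing, and applying SimpleDense, assuming the released Linen entry points Module.init / Module.apply (not spelled out on these slides); the shapes are illustrative.

import jax
import jax.numpy as jnp

model = SimpleDense(out_features=4)
x = jnp.ones((1, 3))                               # batch of 3-feature inputs
variables = model.init(jax.random.PRNGKey(0), x)   # kernel shape (3, 4) inferred from x
y = model.apply(variables, x)                      # y.shape == (1, 4)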
class DotProductAttention(nn.Module):
  qkv_features: int
  out_features: int

  def setup(self):
    QKVDense = functools.partial(Dense, features=self.qkv_features, use_bias=False)
    self.query = QKVDense()
    self.key = QKVDense()
    self.value = QKVDense()
    self.attn = RawDotProductAttention()
    self.out = Dense(features=self.out_features)

  def __call__(self, inputs_q, inputs_kv, bias=None):
    query = self.query(inputs_q)
    key = self.key(inputs_kv)
    value = self.value(inputs_kv)
    y = self.attn(query, key, value, bias=bias)
    y = self.out(y)
    return y
Dot Product Attention: defining submodules in setup
class DotProductAttention(nn.Module):
  qkv_features: int
  out_features: int

  @compact
  def __call__(self, inputs_q, inputs_kv, bias=None):
    QKVDense = functools.partial(Dense, features=self.qkv_features, use_bias=False)
    query = QKVDense(name='query')(inputs_q)
    key = QKVDense(name='key')(inputs_kv)
    value = QKVDense(name='value')(inputs_kv)
    y = RawDotProductAttention()(query, key, value, bias=bias)
    y = Dense(features=self.out_features, name='out')(y)
    return y
Dot Product Attention: The same with inline submodules
Submodules are also lazily initialized on call.
Submodules without explicit names are automatically named classname_N (see the naming sketch below).
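A small illustration of the auto-naming rule (the module and layer sizes are made up): two anonymous Dense submodules inside a compact __call__ land under classname_N keys in the parameter tree.

class TwoLayer(nn.Module):     # hypothetical example
  @compact
  def __call__(self, x):
    x = Dense(features=16)(x)  # auto-named 'Dense_0'
    x = Dense(features=4)(x)   # auto-named 'Dense_1'
    return x

# The parameter tree is keyed by those auto-generated names, e.g.
# {'param': {'Dense_0': {...}, 'Dense_1': {...}}}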
class ResNet(nn.Module):
  """ResNetV1."""
  num_classes: int
  num_filters: int = 64
  num_layers: int = 50
  train: bool = True
  dtype: Any = jnp.float32

  @compact
  def __call__(self, x):
    block_sizes = _block_size_options[self.num_layers]
    x = nn.Conv(self, self.num_filters, (7, 7), (2, 2),
                padding=[(3, 3), (3, 3)], use_bias=False,
                dtype=self.dtype, name='init_conv')(x)
    x = nn.BatchNorm(self, use_running_average=not self.train,
                     momentum=0.9, epsilon=1e-5,
                     dtype=self.dtype, name='init_bn')(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = ResidualBlock(self, self.num_filters * 2 ** i, strides=strides,
                          train=self.train, dtype=self.dtype)(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self, self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, jnp.float32)
    x = nn.log_softmax(x)
    return x
ResNet using inline submodules and for loops
# Turn single-headed attention into multi-headed attention:
Attn = flax.vmap(DotProductAttention,
                 in_axes=(None, None, None),
                 out_axes=-2,
                 axis_size=self.num_heads,
                 variable_axes={'param': 0},
                 split_rngs={'param': True,
                             'dropout': not self.broadcast_dropout})

# construct and call
y = Attn(self,
         qkv_features=qkv_features // self.num_heads,
         out_features=out_features,
         )(inputs_q, inputs_kv, bias)
Transforming a Module - vmap example
jit and remat are much simpler; scan is similar in complexity (a remat sketch follows the callouts below).
Callouts on the vmap example above:
• flax.vmap transforms the Module class itself
• the usual in_axes / out_axes apply to the call arguments and outputs
• variable_axes gives per-kind axes for the variables
• split_rngs says whether to split RNGs (e.g. for param init and dropout)
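As a sketch of the 'remat is much simpler' note above, assuming this proposal exposes flax.remat analogous to the flax.vmap shown here (an assumption, not shown on these slides): there are no axes or RNG splits to annotate.

# Hypothetical sketch: lifting remat needs no per-kind annotations.
RematBlock = flax.remat(ResidualBlock)   # assumed analogue of flax.vmap
x = RematBlock(self, self.num_filters, strides=(1, 1),
               train=self.train, dtype=self.dtype)(x)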
class AutoEncoder(Module):
  encoder_widths: Iterable
  decoder_widths: Iterable
  input_shape: Tuple = None

  def setup(self):
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths +
                       (jnp.prod(self.input_shape),))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[-len(self.input_shape):] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x
Modules with multiple methods
Methods can be called from inside a module or from outside; no special treatment is needed.
But: @compact can't be used on more than one method.
All parameters and submodules must be defined in setup() so that there’s a single correct view of data.
Multiple methods can be exposed to inside or outside callers, e.g. for autoencoders, transformer encoder-decoders, and flow models (a sketch follows).
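A sketch of calling encode and decode from outside the module, assuming the released Linen API where apply takes a method argument (the widths, shapes, and the MLP submodule are illustrative assumptions):

import jax
import jax.numpy as jnp

model = AutoEncoder(encoder_widths=(32, 8),
                    decoder_widths=(32,),
                    input_shape=(28, 28))
x = jnp.ones((1, 28, 28))
variables = model.init(jax.random.PRNGKey(0), x)       # runs __call__
z = model.apply(variables, x, method=model.encode)     # call a specific method
x_rec = model.apply(variables, z, method=model.decode)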