1 of 31

Linen: a comfortable refinement of Flax

2 of 31

  • Retain Flax’s concise, readable code for the common case of simple layers.
    • We spend far more time debugging and re-reading code written by our past selves than writing it, so put the logic in one place and make it easy to understand.
    • Our users value this.

  • Offer a more familiar Python syntax.
    • lightweight classes using dataclasses
    • sharing by python reference
    • minimize use of context managers

  • Support JAX Transformations and other Fancy Stuff™️ on Modules as simply as possible.
    • Build hard stuff on an explicit functional core that avoids Python OOP insanity.

  • Make it pleasant to use interactively.
    • provide helpful error messages whenever possible
    • make it easy to traverse the parameter state

Goals

3 of 31

[Diagram: a Model containing an Encoder and a Decoder, each made of block1 and block2, each block containing conv and BN layers]

Core view of dataflow

4 of 31

Core view of dataflow

[Diagrams over several slides: the module tree above with its associated variables (e.g. parameters, batch norm stats) flowing alongside the data, shown first in simplified form and then in full detail]

10 of 31

from flax.core import Scope, init, apply

def simple_mlp(scope: Scope, x: Array,
               sizes: Sequence[int],
               act_fn: Callable[[Array], Array] = nn.relu):
  i = scope.variable('counter', 'i', jnp.zeros, ())
  i.value += random.uniform(scope.make_rng('counter'), ())
  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(nn.dense)(x, size)
    x = act_fn(x)
  # output layer
  return scope.child(nn.dense, 'out')(x, sizes[-1])

y, variables = init(simple_mlp)(rngs, x, sizes=(2, 8, 4))
apply(simple_mlp, mutable=('counter',))(variables, x2, rngs=rngs)

  • In the functional core, a module is just a function that takes a “scope” argument.
  • A Scope object provides an explicit API for creating variables, RNGs, and submodules.
  • Variables and RNG sequences have an explicit kind.

11 of 31

  • batch norm stats
  • autoregressive cache
  • differentiable parameters

Track separate ‘Kinds’ of variables
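As a rough sketch (the kind names and module paths below are purely illustrative), the variables produced by a model form a nested dict keyed first by kind, then by module path, so each kind can be handled differently:

# Illustrative structure only -- kind names and module paths are placeholders.
variables = {
    'param':       {'dense_0': {'kernel': ..., 'bias': ...}},
    'batch_stats': {'bn_0':    {'mean': ...,   'var': ...}},
    'cache':       {'attn_0':  {'cached_key': ..., 'cached_value': ...}},
}
# e.g. differentiate only w.r.t. 'param', thread 'batch_stats' through
# apply(..., mutable=('batch_stats',)), and keep 'cache' per decoding step.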

12 of 31

We want to remat, vmap, scan, etc. a module.

Lifting for JAX Transformations

13 of 31

Lifting for JAX Transformations

Transformation Boundary

When needed for the transform, we specify transformation rules (e.g. in-axes, out-axes) per-kind.

14 of 31

from flax.core import lift

# add a group dimension to dense
group_dense = lift.vmap(
    nn.dense,
    in_axes=0,
    variable_in_axes={'param': 0},
    variable_out_axes={'param': 0},
    split_rngs={'param': True})

  • The lift module extends JAX transforms to handle variables and RNGs.
  • Behavior is defined on a per-kind basis.
  • Limitation: lifted transforms operate on scopes, not individual variables.

Lifting for JAX Transformations

15 of 31

from flax.core import lift

# resnet of n blocks with O(sqrt(n)) memory
# and O(1) compile time
def body_fn(scope, x):
  return residual_block(scope, x)

# residual net with 10 x 10 = 100 blocks
return lift.remat_scan(
    body_fn, scope, x, lengths=(10, 10),
    variable_modes={'param': 'scan', 'batch_stats': 'scan'},
    split_rngs={'param': True})

  • Compose transformations for more interesting combinations like scan(remat(scan(f))).
  • Behavior is unambiguous: transforms explicitly define how each kind of state (like batch stats) is handled.
  • lift.remat_scan is just a utility combining lift.remat and lift.scan.

Lifting for JAX Transformations

16 of 31

Proposed User-facing Module API

17 of 31

class SimpleDense(nn.Module):
  in_features: int
  out_features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  use_bias: bool = True

  def setup(self):
    self.kernel = self.param('kernel', self.kernel_init,
                             (self.in_features, self.out_features))
    if self.use_bias:
      self.bias = self.param('bias', self.bias_init,
                             (self.out_features,))

  def __call__(self, inputs):
    y = lax.dot_general(inputs, self.kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())))
    if self.use_bias:
      y = y + self.bias
    return y

Module - define variables in setup, use in call

A Module always has a single naming scope for variables and submodules.

No sharing-by-name permitted, only sharing by reference.

We use setup() instead of __init__() because dataclasses take over __init__().
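As a usage sketch (assuming the standard init/apply entry points on Module, which are not shown on this slide):

import jax
import jax.numpy as jnp

model = SimpleDense(in_features=3, out_features=4)
x = jnp.ones((1, 3))

# init runs the module once to create its variables; apply reuses them.
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)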

18 of 31

class SimpleDense(nn.Module):
  # no `in_features` -- lazy initialization with shape inference
  out_features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  use_bias: bool = True

  @compact
  def __call__(self, inputs):
    kernel = self.param('kernel', self.kernel_init,
                        (inputs.shape[-1], self.out_features))
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())))
    if self.use_bias:
      bias = self.param('bias',
                        self.bias_init, (self.out_features,))
      y = y + bias
    return y

Module - colocating variable definition and use via @compact

For more concise module definitions, use the @compact method decorator.

Can only be used on a single method (avoids complicated naming/reuse rules).

Lets you co-locate variable definition and use, which can lead to clearer and shorter code.

Avoids the need to duplicate control flow (e.g. conditionals, loops) between setup and call.

Builds on JAX's shape inference -- no need for in_features
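A small sketch of what shape inference buys here, under the same init/apply assumption: the kernel's input dimension is taken from the first data the module sees.

import jax
import jax.numpy as jnp

model = SimpleDense(out_features=4)                 # no in_features required
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
# 'kernel' is created with shape (3, 4), inferred from the dummy input.
y = model.apply(variables, jnp.ones((8, 3)))        # y.shape == (8, 4)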

19 of 31

(Same SimpleDense definition as the previous slide, highlighting the hyperparameters declared as dataclass attributes.)

Module - colocating variable definition and use via @compact

20 of 31

(Same SimpleDense definition again, highlighting the lazily initialized parameters: self.param uses the initializer if the data is not already in scope, otherwise it returns the existing data.)

Module - colocating variable definition and use via @compact
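To make the lazy-initialization behavior concrete, here is a small sketch (again assuming the init/apply entry points; counting_init is a hypothetical initializer introduced only for illustration):

import jax
import jax.numpy as jnp

# Sketch of the behavior described above: the initializer should run during
# init, but not during apply, because the kernel already exists in the
# passed-in variables.
created = []
def counting_init(key, shape):
    created.append(shape)
    return jnp.zeros(shape)

model = SimpleDense(out_features=4, kernel_init=counting_init)
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)   # initializer runs here
y = model.apply(variables, x)                      # existing kernel is reused
assert len(created) == 1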

21 of 31

class DotProductAttention(nn.Module):
  qkv_features: int
  out_features: int

  def setup(self):
    QKVDense = functools.partial(
        Dense, features=self.qkv_features, use_bias=False)
    self.query = QKVDense()
    self.key = QKVDense()
    self.value = QKVDense()
    self.attn = RawDotProductAttention()
    self.out = Dense(features=self.out_features)

  def __call__(self, inputs_q, inputs_kv, bias=None):
    query = self.query(inputs_q)
    key = self.key(inputs_kv)
    value = self.value(inputs_kv)
    y = self.attn(query, key, value, bias=bias)
    y = self.out(y)
    return y

Dot Product Attention: Defining submodules during init

22 of 31

class DotProductAttention(nn.Module):
  qkv_features: int
  out_features: int

  @nn.compact
  def __call__(self, inputs_q, inputs_kv, bias=None):
    QKVDense = functools.partial(
        Dense, features=self.qkv_features, use_bias=False)
    query = QKVDense(name='query')(inputs_q)
    key = QKVDense(name='key')(inputs_kv)
    value = QKVDense(name='value')(inputs_kv)
    y = RawDotProductAttention()(query, key, value, bias=bias)
    y = Dense(features=self.out_features, name='out')(y)
    return y

Dot Product Attention: The same with inline submodules

23 of 31

(Same inline-submodule DotProductAttention as the previous slide.)

Standard Module - inline submodules

Submodules are also lazily initialized on call.

Submodules without an explicit name are automatically named ClassName_N (sketched below).
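Under that naming scheme, the resulting parameter tree for DotProductAttention might look roughly like this (a sketch; the exact kind name and auto-generated names are illustrative):

# Hypothetical layout: explicitly named submodules appear under their given
# names; the unnamed RawDotProductAttention (which holds no parameters here)
# would receive an auto-generated name such as 'RawDotProductAttention_0'.
{
    'param': {
        'query': {'kernel': ...},
        'key':   {'kernel': ...},
        'value': {'kernel': ...},
        'out':   {'kernel': ..., 'bias': ...},
    }
}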

24 of 31

class ResNet(nn.Module):
  """ResNetV1."""
  num_classes: int
  num_filters: int = 64
  num_layers: int = 50
  train: bool = True
  dtype: Any = jnp.float32

  @compact
  def __call__(self, x):
    block_sizes = _block_size_options[self.num_layers]
    x = nn.Conv(self, self.num_filters, (7, 7), (2, 2),
                padding=[(3, 3), (3, 3)], use_bias=False,
                dtype=self.dtype, name='init_conv')(x)
    x = nn.BatchNorm(self, use_running_average=not self.train,
                     momentum=0.9, epsilon=1e-5,
                     dtype=self.dtype, name='init_bn')(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    for i, block_size in enumerate(block_sizes):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = ResidualBlock(self, self.num_filters * 2 ** i,
                          strides=strides, train=self.train,
                          dtype=self.dtype)(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self, self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, jnp.float32)
    x = nn.log_softmax(x)
    return x

ResNet using inline submodules and for loops
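A usage sketch under the same init/apply assumption (input sizes are illustrative; train=False keeps the batch stats read-only, so nothing needs to be marked mutable):

import jax
import jax.numpy as jnp

model = ResNet(num_classes=1000, num_layers=50, train=False)
x = jnp.ones((1, 224, 224, 3), jnp.float32)          # NHWC dummy batch
variables = model.init(jax.random.PRNGKey(0), x)      # params + batch stats
log_probs = model.apply(variables, x)                 # shape (1, 1000)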

25 of 31

# Turn single-headed attention into multi-headed attention:
Attn = flax.vmap(
    DotProductAttention,
    in_axes=(None, None, None),
    out_axes=-2,
    axis_size=self.num_heads,
    variable_axes={'param': 0},
    split_rngs={'param': True,
                'dropout': not self.broadcast_dropout})

# construct and call
y = Attn(self,
         qkv_features=qkv_features // self.num_heads,
         out_features=out_features,
         )(inputs_q, inputs_kv, bias)

Transforming a Module - vmap example

jit and remat are much simpler (see the sketch below); scan is similar in complexity.
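For example, a remat wrapper would look roughly like the following sketch, assuming flax.remat wraps a Module class the same way flax.vmap does above; no per-kind axis rules are needed because variables and RNGs pass through unchanged.

# Sketch only: recompute the attention block's activations in the backward
# pass instead of storing them.
CheckpointedAttn = flax.remat(DotProductAttention)
y = CheckpointedAttn(self,
                     qkv_features=qkv_features,
                     out_features=out_features,
                     )(inputs_q, inputs_kv, bias)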

26 of 31

(Same multi-head vmap example as the previous slide, highlighting that flax.vmap transforms the Module class itself.)

Transform a Module

27 of 31

(Same example, highlighting the usual in_axes / out_axes arguments.)

Transform a Module

28 of 31

(Same example, highlighting the Module-specific arguments listed below.)

Transform a Module

• per-kind in_axes / out_axes

• whether to split RNGs for param-init, dropout

29 of 31

class AutoEncoder(Module):
  encoder_widths: Iterable
  decoder_widths: Iterable
  input_shape: Tuple = None

  def setup(self):
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths +
                       (jnp.prod(self.input_shape),))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[-len(self.input_shape):] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

Modules with multiple methods

Methods can be called from inside the module or from outside, with no special treatment.

But: @compact cannot be used on more than one method.
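A usage sketch, assuming apply accepts a method argument selecting which bound method to run (MLP, the widths, and the shapes here are illustrative):

import jax
import jax.numpy as jnp

ae = AutoEncoder(encoder_widths=(32, 8),
                 decoder_widths=(8, 32),
                 input_shape=(28, 28))
x = jnp.ones((4, 28, 28))
variables = ae.init(jax.random.PRNGKey(0), x)     # runs __call__ = decode(encode(x))

# Call individual methods from the outside:
z = ae.apply(variables, x, method=AutoEncoder.encode)
x_rec = ae.apply(variables, z, method=AutoEncoder.decode)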

30 of 31

(Same AutoEncoder definition as the previous slide.)

Modules with multiple methods

All parameters and submodules must be defined in setup() so that there’s a single correct view of data.

31 of 31

(Same AutoEncoder definition as the previous slide.)

Modules with multiple methods

Multiple methods can be exposed to inside or outside callers, e.g. for autoencoders, transformer encoder-decoders, and flow models.