ChainerX Code Reading
2. Chainer interoperability
2019-06-13
Chainer Core, Akifumi Imanishi
Schedule
2. Chainer interoperability
Support for device classes (from v6)
# Before v6: CUDA-specific code
model = L.Classifier(MLP(args.unit, 10))
if args.gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(args.gpu).use()
    model.to_gpu()  # Copy the model to the GPU
# From v6: device-agnostic code with device classes
device = chainer.get_device(args.device)
model = L.Classifier(MLP(args.unit, 10))
model.to_device(device)
chainer.get_device(device_spec)
>>> chainer.get_device('@numpy') # NumPy
<CpuDevice (numpy)>
>>> chainer.get_device('@cupy:0') # CuPy
<GpuDevice (cupy):0>
>>> chainer.get_device('@intel64') # iDeep
<Intel64Device>
>>> chainer.get_device('native:0') # ChainerX native
<ChainerxDevice native:0>
>>> chainer.get_device('cuda:0') # ChainerX cuda
<ChainerxDevice cuda:0>
device.send
>>> x = cupy.array([1, 2, 3]) # @cupy:0
>>> device = chainer.get_device('cuda:0')
>>> device.send(x) # converts with zero-copy
array([1, 2, 3], shape=(3,), dtype=int64, device='cuda:0')
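Putting these pieces together, device-agnostic code can look roughly like the following sketch (x_batch and t_batch are hypothetical minibatch arrays, not from the slides):
device = chainer.get_device(args.device)   # e.g. '@numpy', '@cupy:0', 'native:0'
model = L.Classifier(MLP(args.unit, 10))
model.to_device(device)                    # move the parameters to the device
x = device.send(x_batch)                   # move inputs (x_batch / t_batch are hypothetical)
t = device.send(t_batch)
loss = model(x, t)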
2. Chainer interoperability
How to support ChainerX in FunctionNode
class ReLU(function_node.FunctionNode):

    def forward_chainerx(self, inputs):
        x, = inputs
        return chainerx.maximum(x, 0),

    def forward_cpu(self, inputs):
        x, = inputs
        y = numpy.maximum(x, 0, dtype=x.dtype)
        self.retain_outputs((0,))
        return utils.force_array(y),

    ...
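A rough usage sketch (the output repr below is written by hand, assuming the default 'native:0' device): calling F.relu on a Variable that wraps a ChainerX ndarray dispatches to forward_chainerx, so the computation stays inside ChainerX.
>>> x = Variable(chainerx.array([-1.0, 0.0, 2.0], dtype='float32'))
>>> y = F.relu(x)   # dispatches to ReLU.forward_chainerx
>>> y.array
array([0., 0., 2.], shape=(3,), dtype=float32, device='native:0')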
How to write a FunctionNode with a fallback
class MaxPoolingND(pooling_nd._PoolingND):

    def forward_chainerx(self, inputs):
        x, = inputs
        if self.return_indices:
            return chainer.Fallback
        if x.device.backend.name == 'cuda':
            if self.ndim not in [2, 3]:
                return chainer.Fallback
        return chainerx.max_pool(
            x, self.ksize, self.stride, self.pad, self.cover_all),

    ...
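The overall dispatch can be pictured roughly as in the sketch below. This is a simplified illustration, not Chainer's actual implementation (apply_sketch is a made-up name); it uses chainer.backend.from_chx / to_chx for the array conversion.
from chainer import backend

def apply_sketch(func, chainerx_inputs):   # hypothetical helper, for illustration only
    outputs = func.forward_chainerx(chainerx_inputs)
    if outputs is chainer.Fallback:
        # Convert ChainerX ndarrays to NumPy/CuPy, run the usual
        # forward_cpu / forward_gpu path, then convert the results back.
        fallback_inputs = tuple(backend.from_chx(x) for x in chainerx_inputs)
        outputs = func.forward(fallback_inputs)
        outputs = tuple(backend.to_chx(y) for y in outputs)
    return outputs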
Use ChainerX ndarrays with chainer.Variable
>>> array = chainerx.arange(5, dtype='float32')
>>> x = Variable(array)
>>> y = F.sum(x * x) # Does not create chainer.FunctionNode
>>> y.backward()
>>> x.grad
array([0., 2., 4., 6., 8.], shape=(5,), dtype=float32, device='native:0')
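For comparison, essentially the same graph can be built directly with the ChainerX API, which is what happens under the hood in the example above (a minimal sketch):
>>> x = chainerx.arange(5, dtype='float32').require_grad()
>>> y = chainerx.sum(x * x)   # builds ChainerX OpNodes, no chainer.FunctionNode
>>> chainerx.backward(y)
>>> x.grad
array([0., 2., 4., 6., 8.], shape=(5,), dtype=float32, device='native:0')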
[Figure: computation graph of y = sum(x * x) with ChainerX. Each Variable (x, x * x, y) wraps a ChainerX ndarray, and the backward graph is built from ChainerX OpNodes connecting the ndarrays, without any chainer.FunctionNode.]
Use ChainerX ndarrays with chainer.Variable
[Figure: the same graph seen from ChainerX internals. Each Array (x, x * x, y) holds an ArrayBody, each ArrayBody references an ArrayNode, and OpNodes connect the ArrayNodes to form the backward graph.]
Use ChainerX ndarrays with chainer.Variable
[Figure: side-by-side comparison. Chainer (NumPy): Variables wrap NumPy ndarrays and the backward graph is built from chainer.FunctionNodes. Chainer (ChainerX): Variables wrap ChainerX ndarrays and the backward graph of OpNodes is built inside ChainerX. ChainerX alone: only ndarrays connected by OpNodes.]
Fallback to Chainer
>>> array = chainerx.arange(5, dtype='float32')
>>> x = Variable(array)
>>> y = F.sum(F.square(x)) # forward_chainerx is not defined in F.square
>>> y.backward()
>>> x.grad
array([0., 2., 4., 6., 8.], shape=(5,), dtype=float32, device='native:0')
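The fallback can be cheap because ChainerX native arrays and NumPy arrays can share memory; below is a hedged sketch of that conversion (not the fallback code itself), using chainerx.to_numpy with copy=False and chainer.backend.to_chx.
>>> a = chainerx.arange(3, dtype='float32')   # on 'native:0'
>>> b = chainerx.to_numpy(a, copy=False)      # NumPy array; no copy on the native device
>>> type(b)
<class 'numpy.ndarray'>
>>> chainer.backend.to_chx(b)                 # back to a ChainerX ndarray
array([0., 1., 2.], shape=(3,), dtype=float32, device='native:0')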
[Figure: graph for the fallback case. F.square, which has no forward_chainerx, appears as a chainer.FunctionNode, while the rest of the graph is still built from ChainerX OpNodes on the ndarrays.]
Fallback to Chainer
[Figure: the same side-by-side comparison for the fallback case. Chainer (NumPy): every operation is a chainer.FunctionNode. Chainer (ChainerX): the operation that falls back becomes a FunctionNode while the others remain ChainerX OpNodes. ChainerX alone: ndarrays connected by OpNodes.]