TensorFlow.js

飯塚 修平 @tushuhei

1

TensorFlow Dev Summit 2018 Recap

自己紹介

飯塚 修平 @tushuhei

  • UX エンジニア @Google Brand Studio APAC
  • 博士(工学)
    • ウェブサイト最適化の研究

2

今日話すこと

  • TensorFlow.js 概観
  • 使ってみた
    • Emoji Scavenger Hunt
    • Gboard 物理フリックバージョン

3

なぜブラウザで ML?

  • インストールが必要ない。
  • プライバシーの問題が少ない。
  • スケーラビリティについて悩まなくて済む。
  • デバイスのセンサーに API でアクセスできる。

さまざまなメリットがある。

4

5

  • ブラウザでの学習
  • 訓練済みモデルの移植

6

ブラウザでの学習(Ops API)

7

ブラウザでの学習(Ops API)

8

ブラウザでの学習(Layers API)

9

直接線形代数しない書き方も可能。

訓練済みモデルの移植

10

TensorFlow もしくは Keras で学習したモデルをインポートすることができる。

TensorFlow.js のアプリケーション

11

Emoji Scavenger Hunt

12

13

[1] Howard, Andrew G., et al. "Mobilenets: Efficient convolutional neural networks for mobile vision applications." arXiv preprint arXiv:1704.04861 (2017).

[1]

転移学習

14

224

224

MobileNet

v1_0.25_224

1000

“toilet tissue”

“tench”

・・・

424

“rugby ball”

“kimono”

・・・

全結合層

入力画像

この重みを新たに学習する

既存の重みを使う

モデルサイズと精度の
トレードオフ

推論速度と精度の
トレードオフ

転移学習

1. データの準備

ディレクトリ名をラベルとして、
訓練に使う画像データを集める。

2. 全結合層の学習

訓練プログラムを走らせる。

15

python3 \ tensorflow/examples/image_retraining/retrain.py \

--image_dir /data/images \

--how_many_training_steps=10000 \

--architecture mobilenet_0.25_224 \

--output_labels /data/output_labels.txt \

--saved_model_dir /data/saved_model

images
├── cat
│ ├── cat1.jpg
│ ├── cat2.jpg
│ └── ...
└── dog
│ ├── dog1.jpg
│ ├── dog2.jpg
│ └── ...
...

学習のパイプライン

16

Cloud Storage

Compute Engine

TensorFlow.js

Converter

Developer

訓練画像データ

gsutil sync

SavedModel

Web-friendly

SavedModel

TensorFlow.js Converter

17

python3 -m tensorflowjs.converters.converter \

--input_format=tf_saved_model \

--output_node_names='final_result' \

--saved_model_tags=serve \

/data/saved_model/ \

/data/saved_model_web/

  • web_model.pb
    • グラフ構造
  • weights_manifest.json
    • 重みファイルとの対応
  • group1-shard\*of\*
    • 実際の重み
    • 1 ファイル 4MB 以内になるようシャーディングされる。
      → ブラウザキャッシュの活用

フロントエンド実装の概観

export class MobileNet {

async load() {

this.model = await loadFrozenModel(

'/model/web_model.pb', '/model/weights_manifest.json');

}

predict(input: tfc.Tensor): tfc.Tensor1D {

const reshapedInput = input.reshape([1, ...input.shape]);

const dict: TensorMap = {};

dict[INPUT_NODE_NAME] = reshapedInput;

return this.model.execute(dict, OUTPUT_NODE_NAME) as tfc.Tensor1D;

}

}

18

効率的な GPU メモリ管理

推論を何度も繰り返していると動作が重くなってきた
→ GPU のメモリリークを疑う。

  • dispose
    • tensor や variable の除去
  • tf.tidy
    • 渡された関数内で作られた
      tensor を除去し、メモリ解放する。

19

const result = tf.tidy(() => {

const pixels = tf.fromPixels(videoElement);

return this.MobileNet.predict(pixels);

})

Emoji Scavenger Hunt まとめ

  • MobileNet の転移学習
  • GCP を活用した学習パイプラインの構築
  • TensorFlow.js Converter によるモデルの移植

20

TensorFlow.js のアプリケーション

21

Gboard 物理手書きバージョン

22

Twitter 上での反応

試してみた&精度が良い等、反応も上々。

23

24

データ入力者

訓練データ準備用
ウェブアプリ

訓練
プログラム

G

物理手書き
コンバーター

ユーザ

デバイス

入力

文字

Converter

ブラウザ

ユーザ

入力

文字

訓練済みモデル

(GraphDef 形式)

訓練済みモデル

(SavedModel 形式)

訓練済みモデル
(Web-friendly SavedModel形式)

入力

訓練データ

1. 機械学習モデルの構築

2. ハードウェアでの推論

3. ブラウザでの推論

訓練データの準備

訓練データ準備用のウェブアプリを制作するところからプロジェクト開始。

25

26

27

45,803 records

データ構造

{
"id": 5788999721418752,
"writer": "ffb0dac6b8be3faa81da320e29a2ba72",
"kana": "\u307b",
"events": [
["t", "down", 0],
["g", "down", 40],

... ...
["l", "down", 966],
["l", "up", 1005]
]
}

28

データの前処理

29

[("e", 0), ("d", 55), ("c", 102), ("u", 428), ("j", 507)]

[((2.5, 1), 0), ((3.0, 2), 55), ((3.5, 3), 102),
((6.5, 1), 428), ((7.0, 2), 507)]

[((0.00, 0.00), 0), ((0.11, 0.50), 55), ((0.22, 1.00), 102),
((0.89, 0.00), 428), ((1.00, 0.50), 507)]

QWERTY (JIS) を前提として、キー情報を
(x, y) 座標に変換。

(x, y) 座標化

min-max 正規化

[((0.11, 0.50), 55), ((0.11, 0.50), 47),
((0.67, -1.00), 326), ((0.11, 0.50), 79)]

差分

[((0.11, 0.50), true), ((0.11, 0.50), true),
((0.67, -1.00), false), ((0.11, 0.50), true)]

pen-down エンコード

画像化

入力画像

ドメイン知識の活用

30

方向分解特徴 (directional features) [2]

時間的特徴の利用 (temporal features)

▶ 合計 10 チャンネルを利用
 16px 四方の正方形にレンダリングし、最終的に 16x16x10 の入力サイズになった。

[2] Zhang, Xu-Yao, Yoshua Bengio, and Cheng-Lin Liu. "Online and offline handwritten chinese character recognition: A comprehensive study and new benchmark." Pattern Recognition 61 (2017): 348-360.

ネットワーク構造

31

32

8

8

64

8

8

128

4

4

88

Full Connected

Separable

Convolution

(3x3 kernel, stride: 1, ReLu)

Separable

Convolution

(3x3 kernel, stride: 2, ReLu)

Separable

Convolution

(3x3 kernel, stride: 1, ReLu)

128

4

4

Convolution

(3x3 kernel, stride: 2, ReLu)

128

Average

Pooling

▶ Separable Convolution を使うことによる軽量化: 数 MB → 約 160 KB
読み込みを感じさせない UX を実現できた。

10

16

16

=

Separable Convolution [3]

32

普通の Convolution

分離した Convolution

空間方向

チャンネル方向

掛け算を足し算化することで、パラメータ数を減らす。

[3] L. Sifre. Rigid-motion scattering for image classification. PhD thesis, Ph. D. thesis, 2014.

Images are from [2].

各特徴量による正確度の変化

▶ 方向分解特徴、時間的特徴ともに正確度の向上に効果が認められた。
ドメイン知識の利用によって、サイズを保ったまま正確度を向上できた。

33

ビジュアル表現上の工夫

topK の予測値を各要素の opacity として使う。

keydown event のたびに
推論・更新することで、
アニメーションを実現。

▶ フロントエンド ML
ならでは!

34

top5 probability

[0.8, 0.1, 0.07, 0.02, 0.01]

opacity 値として利用

Gboard 物理手書きバージョン まとめ

  • 訓練データの準備からモデル構築までフルスクラッチ。
    • ウェブアプリを含めて、効率的にデータを準備する仕組みづくり
  • ドメイン知識を用いた特徴量抽出、軽量なモデルアーキテクチャの利用
    → よりよい UX のためには性能をキープした軽量化の工夫が重要。
  • フロントエンド ML による新たなビジュアル表現
    • 高頻度に推論を繰り返せる → 新たな表現の可能性!

35

リンク

技術的解説
Google Developers Japan: Gboard 物理手書きバージョンの舞台裏
https://developers-jp.googleblog.com/2018/04/tegaki.html

訓練データ、学習プログラム、ハードウェア設計図
GitHub - Gboard Physical Handwriting Version
https://github.com/google/mozc-devices/tree/master/mozc-nazoru

36

全体のまとめ

  • TensorFlow.js によってフロントエンド ML が強化された。
  • 学習はサーバサイドでもクライアントサイドでも OK。
  • ブラウザ x ML で広がる新たな表現の可能性!

37

Thank you!

@tushuhei

38

27:21

TensorFlow.js - TensorFlow Dev Summit 2018 Recap - Google Slides