1 of 42

Powering large-scale similarity search with deep learning

A practical case study discussing machine learning techniques and scaling challenges to consider.

Jeremy Jordan

Proofpoint, Inc.

2 of 42

Motivating Examples

What products are similar to this?

(Pinterest, Wayfair, Curalate)

What are related legal documents?

(MIT Tech Review, Catalyst, Luminance)

Whose face is this?

(Baidu, Apple FaceID)

3 of 42

Facial Recognition

[Figure: a submitted photo and a set of known identities are each mapped to an embedding vector, e.g. [0.275, 0.268, 0.037, ..., 0.089]; recognition reduces to comparing the submitted photo's vector against the known identities' vectors.]

4 of 42

[Figure: a reference set of N vectors and a set of query vectors; brute-force search computes the distance from each query to every one of the N reference vectors (e.g. 13.2, 22.4, 5.9, ...).]

Calculated Distances (brute force)

5 of 42

[Figure: the same reference set and queries; approximate search computes distances against only a small subset of the reference vectors.]

Calculated Distances (approximate)

6 of 42

Overview

  • Finding similar objects… how?
    • Deep learning is great at learning convenient representations of objects
      • Images (e.g. convnets pretrained on ImageNet)
      • Text (e.g. language models)
    • Using metric learning to calibrate your network
  • Approximate nearest neighbors search
    • Why approximate?
    • K-d trees
    • Quantization
    • Product quantization
    • Locally optimized PQ
  • Useful Python libraries

7 of 42

Learning Representations

8 of 42

Learning representations with neural networks

Input Image → extract low-level features (edges, colors, texture) → extract mid-level features (simple shapes) → extract high-level features (complex shapes)

9 of 42

(video slide)

10 of 42

Inspecting vectors obtained from a CNN

[1.1995, 1.1982, 0., ..., 0., 0.2172, 1.4938]

[1.7438, 0.1053, 0.1708, ..., 0.3663, 0.0168, 0.]

[0.1314, 0.1715, 0.2535, ..., 0.1383, 0.0316, 0.4832]

Vectors obtained by global average pooling on the last convolutional layer of a ResNet50 model.
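A sketch of that pooling step in NumPy (the feature map here is random data standing in for real ResNet50 activations):

```python
import numpy as np

# Stand-in for the activations of the last convolutional layer of a
# ResNet50: a batch of 3 images, 2048 channels, 7x7 spatial grid.
feature_map = np.random.rand(3, 2048, 7, 7)

# Global average pooling: average each channel over its spatial grid,
# producing one 2048-d embedding vector per image.
embeddings = feature_map.mean(axis=(2, 3))

print(embeddings.shape)  # (3, 2048)
```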

11 of 42

Understanding distances in vector space

[Figure: pairwise distances between example images; identical images sit at distance 0, while distinct image pairs sit at distances of 36.76, 42.46, and 46.26.]
12 of 42

Calibrating embeddings through distance metric learning

13 of 42

Calibrating embeddings through distance metric learning

Triplet Loss

[Figure: a triplet of face images — an anchor, a positive (same identity as the anchor), and a negative (different identity).]
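The triplet loss pulls the anchor toward the positive and pushes it away from the negative by at least a margin; a minimal NumPy sketch (the margin of 0.2 and the toy embeddings are illustrative):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    """max(0, d(a, p) - d(a, n) + margin), using squared L2 distances."""
    d_pos = np.sum((anchor - positive) ** 2)
    d_neg = np.sum((anchor - negative) ** 2)
    return max(0.0, d_pos - d_neg + margin)

a = np.array([0.0, 0.0])
p = np.array([0.1, 0.0])   # same identity: already close to the anchor
n = np.array([1.0, 1.0])   # different identity: far from the anchor

print(triplet_loss(a, p, n))  # 0.0 -- this triplet already satisfies the margin
```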

14 of 42

Calibrating embeddings through distance metric learning

Embeddings from ResNet18 with ImageNet weights

Embeddings after fine-tuning with a triplet loss

15 of 42

Nearest Neighbor Search

16 of 42

Brute force computation of neighbors

  • O(N) distance computations to find the first nearest neighbor
  • O(N log k) distance computations to find the k nearest neighbors

This scales very poorly when:

  • you want to have a large reference set
  • you want to power large numbers of queries
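For concreteness, brute-force k-nearest-neighbor search in NumPy (the random data and k=3 are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.random((10_000, 64))   # N reference embeddings
query = rng.random(64)

# O(N) distance computations against the full reference set.
dists = np.linalg.norm(reference - query, axis=1)

# Indices of the k nearest neighbors (argpartition avoids a full sort).
k = 3
nearest = np.argpartition(dists, k)[:k]
nearest = nearest[np.argsort(dists[nearest])]
```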

17 of 42

Approximate nearest neighbor methods

  • Trees and forests
  • Neighborhood graphs
  • Locality sensitive hashing
  • Learning to hash
  • Quantization

18 of 42

K-d trees

[Figure: a k-d tree over 2-D points. Each internal node splits on one coordinate (X0 ≥ 1 at the root, then X1 ≥ 5 and X1 ≥ 0 below it, and so on), and each leaf holds a small partition of the vectors, e.g. (3, -1), (6, -2), (4, -1).]

partitioned vectors

19 of 42

K-d trees

20 of 42

K-d trees

[Figure: the same k-d tree, annotated with its levels — Level 1 at the root through Level 4 at the partitioned vectors in the leaves; for a balanced tree, depth grows logarithmically with the number of vectors.]

21 of 42

K-d trees

[Figure: querying the k-d tree with the point (4, -2); the search starts at the root split X0 ≥ 1.]

Query: (4, -2)

22 of 42

K-d trees

[Figure: after the root comparison (4 ≥ 1 is true), the left half of the tree is pruned; the search continues into the subtree rooted at X1 ≥ 0.]

Query: (4, -2)

23 of 42

K-d trees

[Figure: after the X1 ≥ 0 comparison (-2 < 0), only the subtree rooted at X0 ≥ 7 remains, containing (3, -1), (6, -2), (4, -1), (9, -4), (7, -2), (9, -3).]

Query: (4, -2)

24 of 42

K-d trees

[Figure: after the X0 ≥ 7 comparison (4 < 7), the search reaches a leaf containing (3, -1), (6, -2), (4, -1).]

Query: (4, -2)

Now we only have to calculate distances against a small subset of vectors
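A toy k-d tree over the 2-D points from these slides, as a sketch rather than a production index. Real implementations also backtrack across splits to recover exactness; descent-only search, shown here, is only approximate, and the example below actually returns (6, -2) even though (4, -1) is closer:

```python
def build(points, depth=0):
    """Recursively split points on alternating axes (median split)."""
    if len(points) <= 2:                      # small leaf: stop splitting
        return ("leaf", points)
    axis = depth % 2                          # cycle through the 2 dimensions
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return ("node", axis, points[mid][axis],
            build(points[:mid], depth + 1),
            build(points[mid:], depth + 1))

def query(tree, q):
    """Descend to one leaf, then brute-force only its few points."""
    while tree[0] == "node":
        _, axis, split, left, right = tree
        tree = left if q[axis] < split else right
    return min(tree[1], key=lambda p: (p[0] - q[0])**2 + (p[1] - q[1])**2)

pts = [(-1, 0), (-2, 1), (0, 1), (3, -1), (6, -2), (4, -1),
       (9, -4), (7, -2), (1, 3), (4, 2), (7, 1), (5, 2)]
tree = build(pts)
print(query(tree, (4, -2)))  # (6, -2): the true nearest (4, -1) is missed
```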

25 of 42

Quantization

from sklearn.cluster import KMeans

quantizer = KMeans(n_clusters=7).fit(X)

Partition vector space according to closest centroid

26 of 42

Quantization

[Figure: each vector is mapped to its nearest centroid; the distance between a vector and its centroid is the quantization distortion (“error”).]

27 of 42


28 of 42

Quantization

Each vector id is mapped to its nearest centroid:

    id   centroid
     1      1
     2      1
     3      3
     4      7
     5      7
   ...    ...
   400      5

Inverting that mapping gives an inverted index: for each centroid (1–7), the list of vector ids assigned to its cell (e.g. centroid 1 → ids 1, 2, …).

We quantize our data by mapping all vectors onto the nearest centroid

29 of 42

Quantization

[Figure: seven Voronoi cells, each centroid (1–7) paired with the list of vector ids that fall in its cell.]

Assuming each vector has an identifier, we can maintain a list of vectors for each Voronoi cell. This type of data structure is known as an inverted index.


30 of 42

Quantization

[Figure: the inverted index from the previous slide, with a query vector shown in orange.]

Performing queries: for a given query vector (shown in orange), use k-means predict to find its nearest centroid, then look up the vectors associated with that centroid in our inverted index.
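A NumPy sketch of those query steps; the random vectors and centroids stand in for real embeddings and trained k-means cluster centers:

```python
import numpy as np
from collections import defaultdict

rng = np.random.default_rng(0)
vectors = rng.random((400, 2))     # reference set with ids 0..399
centroids = rng.random((7, 2))     # stand-in for KMeans cluster centers

def nearest_centroid(v):
    """Equivalent of KMeans.predict for a single vector."""
    return int(np.argmin(np.linalg.norm(centroids - v, axis=1)))

# Build the inverted index: centroid id -> list of vector ids in that cell.
inverted_index = defaultdict(list)
for vid, v in enumerate(vectors):
    inverted_index[nearest_centroid(v)].append(vid)

# Query: find the query's cell, then search only the vectors in that cell.
query = vectors[123]               # reuse a known vector as a demo query
cell = nearest_centroid(query)
candidates = inverted_index[cell]
best = min(candidates, key=lambda vid: np.linalg.norm(vectors[vid] - query))
print(best)  # 123 -- the query's own id comes back
```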

31 of 42

Product Quantization

8D example:

    1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11

    subdivide → [1.93, 3.21] [-8.32, 9.12] [-2.11, -5.32] [3.82, 1.11]

    quantize each sub-vector with its own quantizer → codes 2, 5, 1, 2

2D example (easier to visualize) — fit one quantizer per dimension:

    kmeans_0.fit(X[:, 0].reshape(-1, 1))
    kmeans_1.fit(X[:, 1].reshape(-1, 1))
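A sketch of product quantization encoding in NumPy; the codebooks below are made-up stand-ins for per-sub-vector k-means centroids:

```python
import numpy as np

x = np.array([1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11])

# One codebook per 2-D sub-vector; in practice each comes from k-means
# trained on that slice of the data. These values are illustrative only.
codebooks = [
    np.array([[0.0, 0.0], [2.0, 3.0], [-1.0, 4.0]]),
    np.array([[-8.0, 9.0], [5.0, 5.0], [0.0, 0.0]]),
    np.array([[-2.0, -5.0], [3.0, 3.0], [1.0, 1.0]]),
    np.array([[4.0, 1.0], [-3.0, -3.0], [0.0, 2.0]]),
]

def pq_encode(x, codebooks):
    """Split x into equal sub-vectors; encode each as its nearest centroid id."""
    subs = np.split(x, len(codebooks))
    return [int(np.argmin(np.linalg.norm(cb - s, axis=1)))
            for s, cb in zip(subs, codebooks)]

codes = pq_encode(x, codebooks)
print(codes)  # [1, 0, 0, 0] -- one small centroid id per sub-vector
```

Storing four small ids instead of eight floats is where the memory savings come from.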

32 of 42

Product Quantization

[Figure: Centroids Needed for Recall@5 = 90% (borrowed figure from this talk).]

33 of 42

Product Quantization

coarse_quantizer = KMeans(n_clusters=4).fit(X)

Run product quantization on each coarse cell separately

Subdividing your space for multi-modal data

34 of 42

Product Quantization

Repeat for all coarse cells

Subdividing your space for multi-modal data

35 of 42

Locally optimized product quantization

Locally optimize → PCA-align the product centroids in each coarse cell

36 of 42

Locally optimized product quantization

[Figure: the product quantization pipeline from before — the 8-D vector 1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11 is subdivided into four sub-vectors and each is quantized separately.]

When we do our PCA alignment, use the eigenvalues to allocate dimensions into sub-vectors of equal variance.

37 of 42

Locally optimized product quantization

8D example:

    1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11

PCA alignment sorts the resulting dimensions in terms of decreasing variance:

    3.24, -1.20, 2.74, -6.87, 8.10, 4.61, -5.41, 4.77

Use the eigenvalues to sort the dimensions into buckets (sub-vectors) of roughly equal variance:

    [3.24, 4.77] [-1.20, -5.41] [2.74, 4.61] [-6.87, 8.10]

    quantize each sub-vector → codes 3, 1, 6, 3
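One way to sketch that allocation: greedily place dimensions, in order of decreasing eigenvalue, into the open bucket with the least accumulated variance (the eigenvalues below are illustrative):

```python
import numpy as np

def allocate(eigenvalues, n_buckets):
    """Assign PCA dimensions to buckets of (roughly) equal total variance."""
    size = len(eigenvalues) // n_buckets      # equal-sized sub-vectors
    buckets = [[] for _ in range(n_buckets)]
    totals = np.zeros(n_buckets)
    # Visit dimensions by decreasing eigenvalue (variance).
    for dim in np.argsort(eigenvalues)[::-1]:
        # Pick the not-yet-full bucket with the smallest accumulated variance.
        open_ = [i for i in range(n_buckets) if len(buckets[i]) < size]
        best = min(open_, key=lambda i: totals[i])
        buckets[best].append(int(dim))
        totals[best] += eigenvalues[dim]
    return buckets

eig = np.array([9.0, 6.0, 4.0, 3.0, 2.0, 1.5, 1.0, 0.5])
print(allocate(eig, 4))  # [[0, 7], [1, 6], [2, 5], [3, 4]]
```

Note how each high-variance dimension gets paired with a low-variance one, balancing the buckets.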

38 of 42

Locally optimized product quantization

(figure from LOPQ paper)

39 of 42

Available libraries for approximate nearest neighbor search

  • Facebook AI Similarity Search (FAISS)
    • State-of-the-art methods, low-level documentation
  • Spotify Annoy
    • Fewer features, developer friendly
  • Non-Metric Space Library (NMSLIB)
    • Focus on generality, currently only supports static datasets
  • Scikit-learn Nearest Neighbors
    • KDTree and BallTree classes can be useful for quick prototyping

40 of 42

Visual similarity search architecture

[Figure: visual similarity search architecture. An image is fed to a neural network, which produces a vector representation (e.g. 0.124, 1.324, 0.532, …, 0.432). An approximate nearest neighbor (ANN) index over the reference vectors (ids 0, 1, 2, …, N) is queried for similar vectors, and the corresponding visually similar images are returned.]

41 of 42

Relevant Papers

Metric learning

Approximate nearest neighbors

42 of 42

Thanks!

Questions?