Powering large-scale similarity search with deep learning
A practical case study discussing machine learning techniques and scaling challenges to consider.
Jeremy Jordan
Proofpoint, Inc.
Motivating Examples
What are related legal documents?
Whose face is this?
Facial Recognition
[0.275, 0.268, 0.037, ..., 0.089]
[0.050, 0.283, 0.000, ..., 0.017]
[0.276, 0.632, 0.020, ..., 0.782]
[0.016, 2.930, 0.0197, ..., 0.270]
[0.209, 0.056, 0.033, ..., 0.160]
Known Identities
Submitted Photo
Calculated Distances (brute force)
[Figure: every query is compared against every vector (1 … N) in the reference set, filling a complete matrix of distances]
Calculated Distances (approximate)
[Figure: the same queries and reference set, but only a small subset of the distances is ever computed; most comparisons are skipped]
Overview
Learning Representations
Learning representations with neural networks
Input Image
Extract low level features
(edges, colors, texture)
Extract mid level features
(simple shapes)
Extract high level features
(complex shapes)
Inspecting vectors obtained from a CNN
[1.1995, 1.1982, 0.0000, ..., 0.0000, 0.2172, 1.4938]
[1.7438, 0.1053, 0.1708, ..., 0.3663, 0.0168, 0.0000]
[0.1314, 0.1715, 0.2535, ..., 0.1383, 0.0316, 0.4832]
Vectors obtained by global average pooling on the last convolutional layer of a ResNet50 model.
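The pooling step above can be sketched in a few lines of numpy. The only assumption is the feature-map shape (a 7×7 grid with 2048 channels, which is what ResNet50 produces for a 224×224 input); random values stand in for real activations:

```python
import numpy as np

# Stand-in for the last convolutional layer's output of a ResNet50:
# a 7x7 spatial grid with 2048 channels (for a 224x224 input image).
feature_map = np.random.rand(7, 7, 2048)

# Global average pooling: average each channel over all spatial positions,
# collapsing the feature map into a single 2048-d embedding vector.
embedding = feature_map.mean(axis=(0, 1))
print(embedding.shape)  # (2048,)
```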
Understanding distances in vector space
[Figure: pairwise distances between example embeddings; each image is at distance 0 from itself, while distinct pairs sit at distances of 36.76, 42.46, and 46.26]
Calibrating embeddings through distance metric learning
Triplet Loss: train on (anchor, positive, negative) triplets, pulling the anchor toward the positive example and pushing it away from the negative.
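As a rough sketch in plain numpy (with a hypothetical margin of 1.0, not necessarily the value used in training), the triplet loss looks like:

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss on embedding vectors (numpy sketch).

    Pulls the anchor toward the positive example and pushes it away from
    the negative example until they are separated by at least `margin`.
    """
    d_pos = np.linalg.norm(anchor - positive)   # anchor-positive distance
    d_neg = np.linalg.norm(anchor - negative)   # anchor-negative distance
    return max(d_pos - d_neg + margin, 0.0)

# A satisfied triplet (negative already far away) contributes zero loss.
anchor   = np.array([0.0, 0.0])
positive = np.array([0.1, 0.0])
negative = np.array([5.0, 0.0])
print(triplet_loss(anchor, positive, negative))  # 0.0
```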
Calibrating embeddings through distance metric learning
Embeddings from ResNet18 with ImageNet weights
Embeddings after fine-tuning with a triplet loss
Nearest Neighbor Search
Brute force computation of neighbors: compute the distance from each query to every vector in the reference set.
This scales very poorly when:
- the reference set is large (N grows into the millions)
- the vectors are high-dimensional
- queries are frequent
Every query costs O(N · D) distance computations.
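A minimal numpy sketch of the brute-force approach (toy data, not the talk's actual pipeline):

```python
import numpy as np

def brute_force_nn(queries, reference, k=1):
    """Exact nearest neighbors: O(N * D) distance computations per query."""
    # Pairwise Euclidean distances, shape (num_queries, num_reference).
    dists = np.linalg.norm(queries[:, None, :] - reference[None, :, :], axis=-1)
    # Indices of the k closest reference vectors for each query.
    return np.argsort(dists, axis=1)[:, :k]

reference = np.array([[0.0, 0.0], [5.0, 5.0], [1.0, 1.0]])
queries = np.array([[0.9, 1.1]])
print(brute_force_nn(queries, reference, k=2))  # [[2 0]]
```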
Approximate nearest neighbor methods
K-d trees
[Figure: a k-d tree recursively partitions the space with axis-aligned splits: X0 ≥ 1 at the root, then splits such as X1 ≥ 0 and X1 ≥ 5 further down; "no" branches go left, "yes" branches go right, and each leaf holds a small bucket of partitioned vectors, e.g. (-1, 0), (-2, 1), (0, 1), …]
K-d trees
[Figure: the same k-d tree annotated by depth, from Level 1 at the root down to Level 4 at the leaf buckets of partitioned vectors]
K-d trees
Query: (4, -2)
[Figure: the query descends the tree one comparison at a time: X0 ≥ 1? yes; X1 ≥ 0? no; X0 ≥ 7? no; it lands in a single leaf bucket containing (3, -1), (6, -2), (4, -1), …]
Now we only have to calculate distances against a small subset of vectors
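A sketch of the same lookup using scipy's off-the-shelf KDTree, assuming a handful of the 2-D points from the figure; scipy handles the tree construction and the pruned traversal:

```python
import numpy as np
from scipy.spatial import KDTree

# A few of the 2-D points from the figure above.
points = np.array([
    [-1, 0], [-2, 1], [0, 1],
    [-2, 4], [0, 3], [-3, 2],
    [3, -1], [6, -2], [4, -1],
    [9, -4], [7, -2], [9, -3],
])
tree = KDTree(points)

# The query from the slides: only a small subtree is actually searched.
dist, idx = tree.query([4, -2], k=1)
print(points[idx], dist)  # nearest neighbor is (4, -1) at distance 1.0
```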
Quantization
from sklearn.cluster import KMeans
quantizer = KMeans(n_clusters=7).fit(X)
Partition vector space according to closest centroid
Quantization
map vector to centroid
quantization distortion (“error”)
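The two steps above (map each vector to its centroid, measure the distortion) can be sketched with sklearn's KMeans; the data and cluster count here are purely illustrative:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 2))                      # synthetic 2-D vectors

quantizer = KMeans(n_clusters=7, n_init=10, random_state=0).fit(X)

# Map each vector to its nearest centroid...
reconstructed = quantizer.cluster_centers_[quantizer.labels_]

# ...and measure the quantization distortion ("error"): the mean squared
# difference between the original vectors and their centroid stand-ins.
distortion = np.mean(np.sum((X - reconstructed) ** 2, axis=1))
```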
Quantization
Index (id → centroid):
  id | centroid
   1 | 1
   2 | 1
   3 | 3
   4 | 7
   5 | 7
   … | …
 400 | 5

Inverted Index (centroid → ids):
  centroid | ids
         1 | 1, 2, 9, 12, …
         2 | 6, 13, 16, 31, …
         3 | 4, 5, 7, 14, …
         … | …
         7 | 3, 2, 22, 29, …
We quantize our data by mapping all vectors onto the nearest centroid
Quantization
[Figure: the vector space partitioned into seven Voronoi cells, numbered 1–7, with the inverted index alongside]
Assuming each vector has an identifier, we can maintain a list of vectors for each Voronoi cell. This type of data structure is known as an inverted index.
Quantization
Performing queries:
For a given query vector (shown in orange):
1. Use k-means predict to find the nearest centroid
2. Look up the vectors associated with that centroid stored in our inverted index
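Putting the build and query steps together, a sketch with sklearn (synthetic data; the cluster count of 7 mirrors the figure, everything else is illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 2))                      # reference vectors

# Build: quantize the space, then invert the assignment (centroid -> vector ids).
quantizer = KMeans(n_clusters=7, n_init=10, random_state=0).fit(X)
inverted_index = {c: np.where(quantizer.labels_ == c)[0] for c in range(7)}

# Query: predict the nearest centroid, then only scan that cell's vectors.
query = rng.normal(size=(1, 2))
cell = quantizer.predict(query)[0]
candidates = inverted_index[cell]
dists = np.linalg.norm(X[candidates] - query, axis=1)
nearest_id = candidates[np.argmin(dists)]
```

Only the vectors in one Voronoi cell are scanned, trading a little recall (the true neighbor may sit in an adjacent cell) for a large speedup.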
Product Quantization
8D example:
1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11
  subdivide → [1.93, 3.21] [-8.32, 9.12] [-2.11, -5.32] [3.82, 1.11]
  quantize (one codebook per sub-vector) → 2, 5, 1, 2
2D example (easier to visualize):
kmeans_0.fit(X[:, 0].reshape(-1, 1))
kmeans_1.fit(X[:, 1].reshape(-1, 1))
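A sketch of the 8-D example in sklearn (random training data; 16 centroids per codebook is illustrative, where production systems typically use 256 so each code fits in one byte):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 8))                     # 8-D training vectors

n_sub, sub_dim = 4, 2                              # four 2-D sub-vectors

# Train one k-means codebook per sub-vector slice.
codebooks = [
    KMeans(n_clusters=16, n_init=4, random_state=0).fit(X[:, i*sub_dim:(i+1)*sub_dim])
    for i in range(n_sub)
]

def pq_encode(v):
    """Compress an 8-D vector into 4 small centroid ids."""
    return [
        int(cb.predict(v[i*sub_dim:(i+1)*sub_dim].reshape(1, -1))[0])
        for i, cb in enumerate(codebooks)
    ]

code = pq_encode(X[0])                             # 4 integers instead of 8 floats
```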
Product Quantization
[Figure: Centroids Needed for Recall@5 = 90% (borrowed figure from this talk)]
Product Quantization
coarse_quantizer = KMeans(n_clusters=4).fit(X)
Run product quantization on each coarse cell separately
Subdividing your space for multi-modal data
Product Quantization
Repeat for all coarse cells
Subdividing your space for multi-modal data
Locally optimized product quantization
Locally optimize → PCA-align the product centroids in each coarse cell
Locally optimized product quantization
[Figure: the plain product quantization pipeline from before: subdivide the raw vector in its original dimension order, then quantize each sub-vector]
When we do our PCA alignment, use the eigenvalues to allocate dimensions into sub-vectors of equal variance.
Locally optimized product quantization
1.93, 3.21, -8.32, 9.12, -2.11, -5.32, 3.82, 1.11
  PCA → 3.24, -1.20, 2.74, -6.87, 8.10, 4.61, -5.41, 4.77 (dimensions sorted by decreasing variance)
  sort into buckets of roughly equal variance → [3.24, 4.77] [-1.20, -5.41] [2.74, 4.61] [-6.87, 8.10]
  quantize (one codebook per sub-vector) → 3, 1, 6, 3
When we do our PCA alignment, the resulting dimensions are sorted in terms of decreasing variance. Use the eigenvalues to allocate dimensions into sub-vectors of roughly equal variance.
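The eigenvalue-based allocation can be sketched greedily: walk the dimensions in decreasing variance and drop each into the least-loaded bucket that still has room (illustrative eigenvalues; the LOPQ paper's exact allocation procedure may differ):

```python
import numpy as np

def allocate_dimensions(eigenvalues, n_buckets):
    """Greedy sketch: assign PCA dimensions (largest variance first) to
    sub-vector buckets so each bucket ends up with roughly equal variance.
    """
    capacity = len(eigenvalues) // n_buckets
    buckets = [[] for _ in range(n_buckets)]
    totals = np.zeros(n_buckets)
    for dim in np.argsort(eigenvalues)[::-1]:      # largest variance first
        # Only consider buckets that still have room for another dimension.
        open_buckets = [b for b in range(n_buckets) if len(buckets[b]) < capacity]
        # Place the dimension where the accumulated variance is smallest.
        target = min(open_buckets, key=lambda b: totals[b])
        buckets[target].append(int(dim))
        totals[target] += eigenvalues[dim]
    return buckets

# 8 eigenvalues in decreasing order -> 4 buckets of 2 dimensions each.
buckets = allocate_dimensions(np.array([8.1, 4.6, 3.2, 2.7, 1.2, 0.9, 0.4, 0.3]), 4)
```

Pairing the highest-variance dimension with the lowest-variance one (and so on inward) is what produces the interleaved sub-vectors shown above.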
Locally optimized product quantization
(figure from LOPQ paper)
Available libraries for approximate nearest neighbor search
Visual similarity search architecture
Image → neural network produces a vector representation of the image (0.124, 1.324, 0.532, …, 0.432)
Query for similar vectors (ids 0, 1, 2, …, N) using approximate nearest neighbor search
Return visually similar images
Relevant Papers
Thanks!
Questions?