initial commit
This commit is contained in:
parent
4f75e43a39
commit
bbe0b2c96b
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
__pycache__/
|
@ -1,3 +1,7 @@
|
|||||||
# hnsw
|
# hnsw
|
||||||
|
|
||||||
Hierarchical Navigable Small World (HNSW) - a proof-of-concept implementation in Python
|
Hierarchical Navigable Small World (HNSW) - a proof-of-concept implementation in Python
|
||||||
|
|
||||||
|
The implementation mainly follows the paper [Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs](https://arxiv.org/abs/1603.09320), with some simplifications.
|
||||||
|
|
||||||
|
Absolute performance is not a goal here. That said, the relative running times of different configurations can still be compared.
|
149
hnsw.py
Normal file
149
hnsw.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from heapq import heappush, heappop
|
||||||
|
from collections import Counter
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def distance(a: np.ndarray, b: np.ndarray):
    """Return the Euclidean (L2) distance between two points.

    Both arguments are coerced with ``np.asarray`` so plain sequences
    (lists, tuples) are accepted as well as NumPy arrays.
    """
    return np.linalg.norm(np.asarray(a) - np.asarray(b))
|
||||||
|
|
||||||
|
def printw(W):
    """Debug helper: print every (distance, node) pair of W between '<==' and '==>' markers."""
    print('<==')
    for entry in W:
        dist, node = entry
        print(dist, node.coord)
    print('==>')
|
||||||
|
|
||||||
|
class HNSWTower:
    """A data point together with its adjacency sets on every layer it occupies.

    Attributes:
        coord: the point's coordinates.
        level: index of the highest layer this tower reaches.
        neighbors: list of sets, one per layer (index 0 .. level), holding
            the towers linked to this one on that layer.
    """

    def __init__(self, coord, level):
        self.coord = coord
        self.level = level
        self.neighbors = [set() for _ in range(level + 1)]

    def connect(self, other, level):
        """Create an undirected link between this tower and `other` on `level`."""
        # set.add is idempotent, so no membership pre-check is needed.
        self.neighbors[level].add(other)
        other.neighbors[level].add(self)

    def disconnect(self, other, level):
        """Remove the undirected link between this tower and `other` on `level`.

        Missing links are ignored. Both directions are removed from the
        per-layer set (`neighbors[level]`), keeping the graph symmetric.
        """
        self.neighbors[level].discard(other)
        other.neighbors[level].discard(self)

    def limit_degree(self, level, M):
        """Drop the farthest neighbor on `level` if the degree exceeds `M`.

        At most one link is removed per call.
        """
        if len(self.neighbors[level]) <= M:
            return
        # max() always yields a node, even when every neighbor is at distance 0.
        farthest = max(
            self.neighbors[level],
            key=lambda other: distance(self.coord, other.coord),
        )
        self.disconnect(farthest, level)
|
||||||
|
|
||||||
|
|
||||||
|
class HNSW:
    """
    Approximate k-Nearest Neighbor Search (k-ANNS) using Hierarchical Navigable Small World (HNSW) graph
    """

    def __init__(self, mL, M, M_max, ef_construction, seed):
        # Enter point (None until the first element is inserted)
        self.ep = None
        # Normalizing factor for level generation
        self.mL = mL
        # Number of connections to establish
        self.M = M
        # Maximum degree for node
        self.M_max = M_max
        # Size of NN to return during construction
        self.ef_construction = ef_construction
        self.rng = random.Random(seed)
        # Number of inserted elements per generated top level
        self.level_count = Counter()

    def generate_level(self):
        """Draw the top level for a new element."""
        # the condition of returning a level higher than L is
        # self.mL * math.log(1/u) >= L
        # u <= exp(-L/self.mL) = exp(-1/self.mL) ^ L
        # equivalent to a "level up probability" of exp(-1/self.mL)
        u = self.rng.random()
        # random() draws from [0.0, 1.0); redraw on 0.0 because log(0) is undefined.
        while u == 0.0:
            u = self.rng.random()
        return math.floor(-self.mL * math.log(u))

    def insert_first(self, coord):
        """Insert the very first element; it becomes the enter point at level 0."""
        self.ep = HNSWTower(coord, 0)
        self.level_count[0] += 1

    def insert(self, coord):
        """Insert a point with coordinates `coord` into the graph."""
        if self.ep is None:
            self.insert_first(coord)
            return
        ep = self.ep
        l = self.generate_level()
        self.level_count[l] += 1
        elem = HNSWTower(coord, l)
        L = self.ep.level
        # Greedy descent through the layers above the new element's top level.
        for level in range(L, l, -1):
            ep = self.route_layer(coord, ep, level)
        # enter points (multiple, sorted by distance from near to far)
        eps = [(distance(coord, ep.coord), ep)]
        for level in range(min(l, L), -1, -1):
            W = self.search_layer(coord, eps, self.ef_construction, level)
            # W is sorted nearest-first, so its prefix holds the M closest candidates.
            neighbors = [node for _, node in W[:self.M]]
            # Create node at this layer
            for other in neighbors:
                elem.connect(other, level)
            # Shrink connections if needed
            for other in neighbors:
                other.limit_degree(level, self.M_max)
            eps = W
        if l > L:
            self.ep = elem

    def route_layer(self, coord, ep, level):
        """Return the node that is closest to `coord` at `level`, starting from `ep`.

        Greedy walk: repeatedly move to the neighbor nearest to `coord` until
        no neighbor is strictly closer than the current node.
        """
        best_d = distance(coord, ep.coord)
        while True:
            closer = None
            for other in ep.neighbors[level]:
                d = distance(coord, other.coord)
                if d < best_d:
                    best_d = d
                    closer = other
            if closer is None:
                return ep
            ep = closer

    def search_layer(self, coord, eps, ef, level):
        """Best-first search on one layer.

        `eps` is a list of (distance, node) enter points and is not modified.
        Returns a list of up to `ef` (distance, node) pairs sorted from nearest
        to farthest (more only if `eps` itself holds more than `ef` entries).
        """
        visited = {node for _, node in eps}
        # Heap entries are (key, seq, node); the unique seq breaks distance ties
        # so that nodes themselves are never compared.
        seq = 0
        # set of candidates organized as priority queue; nearest element has top priority
        C = []
        # dynamic list of found nearest neighbors; distances are negated so the
        # farthest element has top priority
        W = []
        for d, node in eps:
            heappush(C, (d, seq, node))
            heappush(W, (-d, seq, node))
            seq += 1
        while C:
            c_dist, _, c = heappop(C)
            if c_dist > -W[0][0]:
                break  # all elements in W are evaluated
            for e in c.neighbors[level]:
                if e in visited:
                    continue
                visited.add(e)
                e_dist = distance(coord, e.coord)
                if len(W) < ef or e_dist < -W[0][0]:
                    heappush(C, (e_dist, seq, e))
                    heappush(W, (-e_dist, seq, e))
                    seq += 1
                    if len(W) > ef:
                        heappop(W)
        # Undo the negation and sort on distance only (never on the node).
        return sorted(
            ((-neg_d, node) for neg_d, _, node in W),
            key=lambda entry: entry[0],
        )

    def k_nn_search(self, coord, K, ef):
        """Return up to `K` (distance, node) pairs nearest to `coord`, nearest first.

        `ef` is the size of the candidate list kept while searching layer 0.
        An empty index yields an empty list.
        """
        if self.ep is None:
            return []
        ep = self.ep
        for level in range(ep.level, 0, -1):
            ep = self.route_layer(coord, ep, level)
        eps = [(distance(coord, ep.coord), ep)]
        W = self.search_layer(coord, eps, ef, 0)
        return W[:K]
|
80
knntest.py
Normal file
80
knntest.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from hnsw import *
|
||||||
|
|
||||||
|
class SyntheticVectors:
    """Corpus of standard-normal random vectors with a brute-force k-NN used as ground truth."""

    def __init__(self, num_pts, dim, seed):
        # Seeds NumPy's global generator, so the corpus is reproducible for a given seed.
        np.random.seed(seed)
        self.num_pts = num_pts
        self.dim = dim
        self.points = np.random.normal(0, 1, (num_pts, dim))

    def knn(self, query, k):
        """Return the exact `k` nearest corpus points as (distance, point) pairs, nearest first."""
        dists = np.linalg.norm(self.points - query, axis=1)
        # Order on distance only; a stable sort keeps corpus order on ties, and
        # the point arrays themselves are never compared.
        order = np.argsort(dists, kind="stable")[:k]
        return [(dists[i], self.points[i]) for i in order]

    def validate_results(self, query, k, results):
        """Count how many of the true `k` nearest neighbors appear in `results`.

        `results` is a list of (distance, coord) pairs; a ground-truth point
        counts as a hit when some result coordinate is numerically equal to it.
        """
        ground_truth = self.knn(query, k)
        return sum(
            any(
                np.isclose(np.linalg.norm(g_coord - r_coord), 0)
                for _, r_coord in results
            )
            for _, g_coord in ground_truth
        )
|
||||||
|
|
||||||
|
|
||||||
|
class HNSWTester:
    """Benchmark harness: builds an HNSW index over a synthetic corpus and measures recall and query time."""

    def __init__(self, num_corpus, dim, seed):
        self.num_corpus = num_corpus
        self.dim = dim
        self.seed = seed
        self.corpus = SyntheticVectors(num_corpus, dim, seed)
        # Re-seed NumPy's global generator so the queries drawn later by
        # gen_queries come from a different seed than the corpus.
        np.random.seed(seed+1)

    def build_hnsw(self, M, ef_construction):
        """Build the index over the whole corpus, then print and record the build time."""
        self.M, self.ef_construction = M, ef_construction
        # NOTE: `math` reaches this module through `from hnsw import *`.
        mL = 1/math.log(M)
        # using the above ml, the "level up" probability is 1/M.
        # Build the HNSW
        start_time = time.perf_counter()
        self.hnsw = HNSW(mL, M, M, ef_construction, self.seed)
        for ic in range(self.num_corpus):
            self.hnsw.insert(self.corpus.points[ic])
        end_time = time.perf_counter()
        build_time = end_time - start_time
        print("Built HNSW in %.4f seconds" % build_time)
        self.build_time = build_time
        self.print_hnsw_levels()

    def gen_queries(self, num_query):
        """Draw `num_query` standard-normal query vectors of dimension `self.dim`."""
        self.num_query = num_query
        self.queries = np.random.normal(0, 1, (num_query, self.dim))

    def query_hnsw(self, K, ef):
        """Run every query, then store and print recall and average query time in ms.

        Results are kept in `self.recall` and `self.query_ms`; nothing is returned.
        """
        self.K, self.ef = K, ef
        n_hit = 0
        tot_time = 0
        for iq in range(self.num_query):
            start_time = time.perf_counter()
            result = self.hnsw.k_nn_search(self.queries[iq], K, ef)
            result = [(dist, node.coord) for dist, node in result]
            end_time = time.perf_counter()
            tot_time += end_time - start_time
            # Validation against brute force is deliberately outside the timed region.
            n_hit += self.corpus.validate_results(self.queries[iq], K, result)
        self.recall, self.query_ms = n_hit / (self.num_query * K), tot_time * 1000 / self.num_query
        print("recall:%.4f, query time/ms: %.4f" % (self.recall, self.query_ms))

    def print_hnsw_levels(self):
        """Print how many inserted elements were assigned each top level."""
        cnt = self.hnsw.level_count
        L = self.hnsw.ep.level
        for level in range(L+1):
            print(f"level {level}: {cnt[level]} nodes")

    def log_results(self, fname="result.csv"):
        """Append one CSV row with the parameters and measurements of the last run to `fname`."""
        row = (
            f"{self.num_corpus},{self.dim},"
            f"{self.M},{self.ef_construction},{self.build_time:.4f},"
            f"{self.K},{self.ef},{self.recall:.4f},{self.query_ms:.4f},"
            f"{self.num_query},{self.seed}\n"
        )
        # The context manager guarantees the file is flushed and closed.
        with open(fname, 'a', encoding='utf-8') as f:
            f.write(row)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, so importing this
    # module does not trigger an index build.
    tester = HNSWTester(num_corpus=5000, dim=100, seed=43)
    tester.build_hnsw(M=10, ef_construction=10)
    tester.gen_queries(num_query=100)
    tester.query_hnsw(K=5, ef=10)
    # Appends one row to result.csv in the current working directory.
    tester.log_results()
|
21
result.csv
Normal file
21
result.csv
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
N_corpus,Dim,M,ef_construction,build_time,K,ef,recall,query_ms,N_query,seed
|
||||||
|
1000,10,10,50,1.2735,5,10,0.8280,0.4636,100,43
|
||||||
|
1000,20,10,50,1.3208,5,10,0.6200,0.5730,100,43
|
||||||
|
1000,30,10,50,1.2967,5,10,0.6540,0.5083,100,43
|
||||||
|
1000,40,10,50,1.3569,5,10,0.6140,0.4980,100,43
|
||||||
|
1000,50,10,50,1.3414,5,10,0.5340,0.5513,100,43
|
||||||
|
1000,70,10,50,1.3196,5,10,0.5560,0.5107,100,43
|
||||||
|
1000,100,10,50,1.3419,5,10,0.5160,0.5179,100,43
|
||||||
|
1000,100,12,50,1.6164,5,10,0.5760,0.5652,100,43
|
||||||
|
1000,100,14,50,1.8366,5,10,0.6000,0.6434,100,43
|
||||||
|
1000,100,16,50,2.1186,5,10,0.6420,0.6895,100,43
|
||||||
|
1000,100,18,50,2.4118,5,10,0.6880,0.7652,100,43
|
||||||
|
1000,100,20,50,2.7659,5,10,0.6940,0.8060,100,43
|
||||||
|
2000,100,20,50,6.2879,5,10,0.5740,0.9183,100,43
|
||||||
|
3000,100,20,50,10.1791,5,10,0.5280,1.0457,100,43
|
||||||
|
4000,100,20,50,14.1797,5,10,0.4720,1.1173,100,43
|
||||||
|
5000,100,20,50,18.9001,5,10,0.4240,1.1408,100,43
|
||||||
|
5000,100,10,50,9.2855,5,10,0.3220,0.6739,100,43
|
||||||
|
5000,100,10,100,13.6119,5,10,0.3160,0.6867,100,43
|
||||||
|
5000,100,10,20,6.0684,5,10,0.3080,0.6659,100,43
|
||||||
|
5000,100,10,10,4.9333,5,10,0.2520,0.6278,100,43
|
|
Loading…
x
Reference in New Issue
Block a user